Ver código fonte

Merge pull request #1342 from emqtt/race-condition

Improve the pubsub design and fix the race-condition issue
Feng Lee 8 anos atrás
pai
commit
8e4f109c9e

+ 2 - 2
Makefile

@@ -1,6 +1,6 @@
 PROJECT = emqttd
 PROJECT = emqttd
 PROJECT_DESCRIPTION = Erlang MQTT Broker
 PROJECT_DESCRIPTION = Erlang MQTT Broker
-PROJECT_VERSION = 2.3
+PROJECT_VERSION = 2.3.0
 
 
 DEPS = goldrush gproc lager esockd ekka mochiweb pbkdf2 lager_syslog bcrypt clique jsx
 DEPS = goldrush gproc lager esockd ekka mochiweb pbkdf2 lager_syslog bcrypt clique jsx
 
 
@@ -36,7 +36,7 @@ EUNIT_OPTS = verbose
 
 
 CT_SUITES = emqttd emqttd_access emqttd_lib emqttd_inflight emqttd_mod \
 CT_SUITES = emqttd emqttd_access emqttd_lib emqttd_inflight emqttd_mod \
             emqttd_net emqttd_mqueue emqttd_protocol emqttd_topic \
             emqttd_net emqttd_mqueue emqttd_protocol emqttd_topic \
-            emqttd_trie emqttd_vm emqttd_config
+            emqttd_router emqttd_trie emqttd_vm emqttd_config
 
 
 CT_OPTS = -cover test/ct.cover.spec -erl_args -name emqttd_ct@127.0.0.1
 CT_OPTS = -cover test/ct.cover.spec -erl_args -name emqttd_ct@127.0.0.1
 
 

+ 1 - 3
etc/emq.conf

@@ -1,5 +1,5 @@
 ##===================================================================
 ##===================================================================
-## EMQ Configuration R2.3
+## EMQ Configuration R2.3.0
 ##===================================================================
 ##===================================================================
 
 
 ##--------------------------------------------------------------------
 ##--------------------------------------------------------------------
@@ -288,8 +288,6 @@ mqtt.broker.sys_interval = 60
 ## PubSub Pool Size. Default should be scheduler numbers.
 ## PubSub Pool Size. Default should be scheduler numbers.
 mqtt.pubsub.pool_size = 8
 mqtt.pubsub.pool_size = 8
 
 
-mqtt.pubsub.by_clientid = true
-
 ## Subscribe Asynchronously
 ## Subscribe Asynchronously
 mqtt.pubsub.async = true
 mqtt.pubsub.async = true
 
 

+ 4 - 4
include/emqttd_trie.hrl

@@ -17,10 +17,10 @@
 -type(trie_node_id() :: binary() | atom()).
 -type(trie_node_id() :: binary() | atom()).
 
 
 -record(trie_node,
 -record(trie_node,
-        { node_id         :: trie_node_id(),
-          edge_count = 0  :: non_neg_integer(),
-          topic           :: binary() | undefined,
-          flags           :: [retained | static]
+        { node_id        :: trie_node_id(),
+          edge_count = 0 :: non_neg_integer(),
+          topic          :: binary() | undefined,
+          flags          :: [retained | static]
         }).
         }).
 
 
 -record(trie_edge,
 -record(trie_edge,

+ 0 - 5
priv/emq.schema

@@ -715,11 +715,6 @@ end}.
   {datatype, integer}
   {datatype, integer}
 ]}.
 ]}.
 
 
-{mapping, "mqtt.pubsub.by_clientid", "emqttd.pubsub", [
-  {default, true},
-  {datatype, {enum, [true, false]}}
-]}.
-
 {mapping, "mqtt.pubsub.async", "emqttd.pubsub", [
 {mapping, "mqtt.pubsub.async", "emqttd.pubsub", [
   {default, true},
   {default, true},
   {datatype, {enum, [true, false]}}
   {datatype, {enum, [true, false]}}

+ 1 - 1
src/emqttd.app.src

@@ -1,6 +1,6 @@
 {application,emqttd,
 {application,emqttd,
              [{description,"Erlang MQTT Broker"},
              [{description,"Erlang MQTT Broker"},
-              {vsn,"2.3"},
+              {vsn,"2.3.0"},
               {modules,[]},
               {modules,[]},
               {registered,[emqttd_sup]},
               {registered,[emqttd_sup]},
               {applications,[kernel,stdlib,gproc,lager,esockd,mochiweb,
               {applications,[kernel,stdlib,gproc,lager,esockd,mochiweb,

+ 24 - 30
src/emqttd.erl

@@ -31,8 +31,7 @@
          unsubscribe/1, unsubscribe/2]).
          unsubscribe/1, unsubscribe/2]).
 
 
 %% PubSub Management API
 %% PubSub Management API
--export([setqos/3, topics/0, subscriptions/1, subscribers/1,
-         is_subscribed/2, subscriber_down/1]).
+-export([setqos/3, topics/0, subscriptions/1, subscribers/1, subscribed/2]).
 
 
 %% Hooks API
 %% Hooks API
 -export([hook/4, hook/3, unhook/2, run_hooks/2, run_hooks/3]).
 -export([hook/4, hook/3, unhook/2, run_hooks/2, run_hooks/3]).
@@ -43,14 +42,13 @@
 %% Shutdown and reboot
 %% Shutdown and reboot
 -export([shutdown/0, shutdown/1, reboot/0]).
 -export([shutdown/0, shutdown/1, reboot/0]).
 
 
--type(subscriber() :: pid() | binary()).
+-type(subid() :: binary()).
 
 
--type(suboption() :: local | {qos, non_neg_integer()} | {share, {'$queue' | binary()}}).
+-type(subscriber() :: pid() | subid() | {subid(), pid()}).
 
 
--type(pubsub_error() :: {error, {already_subscribed, binary()}
-                              | {subscription_not_found, binary()}}).
+-type(suboption() :: local | {qos, non_neg_integer()} | {share, {'$queue' | binary()}}).
 
 
--export_type([subscriber/0, suboption/0, pubsub_error/0]).
+-export_type([subscriber/0, suboption/0]).
 
 
 -define(APP, ?MODULE).
 -define(APP, ?MODULE).
 
 
@@ -59,19 +57,19 @@
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
 %% @doc Start emqttd application.
 %% @doc Start emqttd application.
--spec(start() -> ok | {error, any()}).
+-spec(start() -> ok | {error, term()}).
 start() -> application:start(?APP).
 start() -> application:start(?APP).
 
 
 %% @doc Stop emqttd application.
 %% @doc Stop emqttd application.
--spec(stop() -> ok | {error, any()}).
+-spec(stop() -> ok | {error, term()}).
 stop() -> application:stop(?APP).
 stop() -> application:stop(?APP).
 
 
 %% @doc Environment
 %% @doc Environment
--spec(env(Key:: atom()) -> {ok, any()} | undefined).
+-spec(env(Key :: atom()) -> {ok, any()} | undefined).
 env(Key) -> application:get_env(?APP, Key).
 env(Key) -> application:get_env(?APP, Key).
 
 
 %% @doc Get environment
 %% @doc Get environment
--spec(env(Key:: atom(), Default:: any()) -> undefined | any()).
+-spec(env(Key :: atom(), Default :: any()) -> undefined | any()).
 env(Key, Default) -> application:get_env(?APP, Key, Default).
 env(Key, Default) -> application:get_env(?APP, Key, Default).
 
 
 %% @doc Is running?
 %% @doc Is running?
@@ -88,15 +86,15 @@ is_running(Node) ->
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
 %% @doc Subscribe
 %% @doc Subscribe
--spec(subscribe(iodata()) -> ok | {error, any()}).
+-spec(subscribe(iodata()) -> ok | {error, term()}).
 subscribe(Topic) ->
 subscribe(Topic) ->
-    subscribe(Topic, self()).
+    emqttd_server:subscribe(iolist_to_binary(Topic)).
 
 
--spec(subscribe(iodata(), subscriber()) -> ok | {error, any()}).
+-spec(subscribe(iodata(), subscriber()) -> ok | {error, term()}).
 subscribe(Topic, Subscriber) ->
 subscribe(Topic, Subscriber) ->
-    subscribe(Topic, Subscriber, []).
+    emqttd_server:subscribe(iolist_to_binary(Topic), Subscriber).
 
 
--spec(subscribe(iodata(), subscriber(), [suboption()]) -> ok | pubsub_error()).
+-spec(subscribe(iodata(), subscriber(), [suboption()]) -> ok | {error, term()}).
 subscribe(Topic, Subscriber, Options) ->
 subscribe(Topic, Subscriber, Options) ->
     emqttd_server:subscribe(iolist_to_binary(Topic), Subscriber, Options).
     emqttd_server:subscribe(iolist_to_binary(Topic), Subscriber, Options).
 
 
@@ -106,11 +104,11 @@ publish(Msg) ->
     emqttd_server:publish(Msg).
     emqttd_server:publish(Msg).
 
 
 %% @doc Unsubscribe
 %% @doc Unsubscribe
--spec(unsubscribe(iodata()) -> ok | pubsub_error()).
+-spec(unsubscribe(iodata()) -> ok | {error, term()}).
 unsubscribe(Topic) ->
 unsubscribe(Topic) ->
-    unsubscribe(Topic, self()).
+    emqttd_server:unsubscribe(iolist_to_binary(Topic)).
 
 
--spec(unsubscribe(iodata(), subscriber()) -> ok | pubsub_error()).
+-spec(unsubscribe(iodata(), subscriber()) -> ok | {error, term()}).
 unsubscribe(Topic, Subscriber) ->
 unsubscribe(Topic, Subscriber) ->
     emqttd_server:unsubscribe(iolist_to_binary(Topic), Subscriber).
     emqttd_server:unsubscribe(iolist_to_binary(Topic), Subscriber).
 
 
@@ -125,34 +123,30 @@ topics() -> emqttd_router:topics().
 subscribers(Topic) ->
 subscribers(Topic) ->
     emqttd_server:subscribers(iolist_to_binary(Topic)).
     emqttd_server:subscribers(iolist_to_binary(Topic)).
 
 
--spec(subscriptions(subscriber()) -> [{binary(), binary(), list(suboption())}]).
+-spec(subscriptions(subscriber()) -> [{emqttd:subscriber(), binary(), list(emqttd:suboption())}]).
 subscriptions(Subscriber) ->
 subscriptions(Subscriber) ->
     emqttd_server:subscriptions(Subscriber).
     emqttd_server:subscriptions(Subscriber).
 
 
--spec(is_subscribed(iodata(), subscriber()) -> boolean()).
-is_subscribed(Topic, Subscriber) ->
-    emqttd_server:is_subscribed(iolist_to_binary(Topic), Subscriber).
-
--spec(subscriber_down(subscriber()) -> ok).
-subscriber_down(Subscriber) ->
-    emqttd_server:subscriber_down(Subscriber).
+-spec(subscribed(iodata(), subscriber()) -> boolean()).
+subscribed(Topic, Subscriber) ->
+    emqttd_server:subscribed(iolist_to_binary(Topic), Subscriber).
 
 
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 %% Hooks API
 %% Hooks API
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
 -spec(hook(atom(), function() | {emqttd_hooks:hooktag(), function()}, list(any()))
 -spec(hook(atom(), function() | {emqttd_hooks:hooktag(), function()}, list(any()))
-      -> ok | {error, any()}).
+      -> ok | {error, term()}).
 hook(Hook, TagFunction, InitArgs) ->
 hook(Hook, TagFunction, InitArgs) ->
     emqttd_hooks:add(Hook, TagFunction, InitArgs).
     emqttd_hooks:add(Hook, TagFunction, InitArgs).
 
 
 -spec(hook(atom(), function() | {emqttd_hooks:hooktag(), function()}, list(any()), integer())
 -spec(hook(atom(), function() | {emqttd_hooks:hooktag(), function()}, list(any()), integer())
-      -> ok | {error, any()}).
+      -> ok | {error, term()}).
 hook(Hook, TagFunction, InitArgs, Priority) ->
 hook(Hook, TagFunction, InitArgs, Priority) ->
     emqttd_hooks:add(Hook, TagFunction, InitArgs, Priority).
     emqttd_hooks:add(Hook, TagFunction, InitArgs, Priority).
 
 
 -spec(unhook(atom(), function() | {emqttd_hooks:hooktag(), function()})
 -spec(unhook(atom(), function() | {emqttd_hooks:hooktag(), function()})
-      -> ok | {error, any()}).
+      -> ok | {error, term()}).
 unhook(Hook, TagFunction) ->
 unhook(Hook, TagFunction) ->
     emqttd_hooks:delete(Hook, TagFunction).
     emqttd_hooks:delete(Hook, TagFunction).
 
 

+ 5 - 5
src/emqttd_access_control.erl

@@ -43,12 +43,12 @@
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
 %% @doc Start access control server.
 %% @doc Start access control server.
--spec(start_link() -> {ok, pid()} | ignore | {error, any()}).
+-spec(start_link() -> {ok, pid()} | ignore | {error, term()}).
 start_link() ->
 start_link() ->
     gen_server:start_link({local, ?SERVER}, ?MODULE, [], []).
     gen_server:start_link({local, ?SERVER}, ?MODULE, [], []).
 
 
 %% @doc Authenticate MQTT Client.
 %% @doc Authenticate MQTT Client.
--spec(auth(Client :: mqtt_client(), Password :: password()) -> ok | {ok, boolean()} | {error, any()}).
+-spec(auth(Client :: mqtt_client(), Password :: password()) -> ok | {ok, boolean()} | {error, term()}).
 auth(Client, Password) when is_record(Client, mqtt_client) ->
 auth(Client, Password) when is_record(Client, mqtt_client) ->
     auth(Client, Password, lookup_mods(auth)).
     auth(Client, Password, lookup_mods(auth)).
 auth(_Client, _Password, []) ->
 auth(_Client, _Password, []) ->
@@ -88,16 +88,16 @@ reload_acl() ->
     [Mod:reload_acl(State) || {Mod, State, _Seq} <- lookup_mods(acl)].
     [Mod:reload_acl(State) || {Mod, State, _Seq} <- lookup_mods(acl)].
 
 
 %% @doc Register Authentication or ACL module.
 %% @doc Register Authentication or ACL module.
--spec(register_mod(auth | acl, atom(), list()) -> ok | {error, any()}).
+-spec(register_mod(auth | acl, atom(), list()) -> ok | {error, term()}).
 register_mod(Type, Mod, Opts) when Type =:= auth; Type =:= acl->
 register_mod(Type, Mod, Opts) when Type =:= auth; Type =:= acl->
     register_mod(Type, Mod, Opts, 0).
     register_mod(Type, Mod, Opts, 0).
 
 
--spec(register_mod(auth | acl, atom(), list(), non_neg_integer()) -> ok | {error, any()}).
+-spec(register_mod(auth | acl, atom(), list(), non_neg_integer()) -> ok | {error, term()}).
 register_mod(Type, Mod, Opts, Seq) when Type =:= auth; Type =:= acl->
 register_mod(Type, Mod, Opts, Seq) when Type =:= auth; Type =:= acl->
     gen_server:call(?SERVER, {register_mod, Type, Mod, Opts, Seq}).
     gen_server:call(?SERVER, {register_mod, Type, Mod, Opts, Seq}).
 
 
 %% @doc Unregister authentication or ACL module
 %% @doc Unregister authentication or ACL module
--spec(unregister_mod(Type :: auth | acl, Mod :: atom()) -> ok | {error, any()}).
+-spec(unregister_mod(Type :: auth | acl, Mod :: atom()) -> ok | {error, not_found | term()}).
 unregister_mod(Type, Mod) when Type =:= auth; Type =:= acl ->
 unregister_mod(Type, Mod) when Type =:= auth; Type =:= acl ->
     gen_server:call(?SERVER, {unregister_mod, Type, Mod}).
     gen_server:call(?SERVER, {unregister_mod, Type, Mod}).
 
 

+ 1 - 1
src/emqttd_acl_mod.erl

@@ -32,7 +32,7 @@
                      PubSub :: pubsub(),
                      PubSub :: pubsub(),
                      Topic  :: binary()}, State :: any()) -> allow | deny | ignore).
                      Topic  :: binary()}, State :: any()) -> allow | deny | ignore).
 
 
--callback(reload_acl(State :: any()) -> ok | {error, any()}).
+-callback(reload_acl(State :: any()) -> ok | {error, term()}).
 
 
 -callback(description() -> string()).
 -callback(description() -> string()).
 
 

+ 1 - 1
src/emqttd_bridge_sup.erl

@@ -23,7 +23,7 @@
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
 %% @doc Start bridge pool supervisor
 %% @doc Start bridge pool supervisor
--spec(start_link(atom(), binary(), [emqttd_bridge:option()]) -> {ok, pid()} | {error, any()}).
+-spec(start_link(atom(), binary(), [emqttd_bridge:option()]) -> {ok, pid()} | {error, term()}).
 start_link(Node, Topic, Options) ->
 start_link(Node, Topic, Options) ->
     MFA = {emqttd_bridge, start_link, [Node, Topic, Options]},
     MFA = {emqttd_bridge, start_link, [Node, Topic, Options]},
     emqttd_pool_sup:start_link({bridge, Node, Topic}, random, MFA).
     emqttd_pool_sup:start_link({bridge, Node, Topic}, random, MFA).

+ 2 - 2
src/emqttd_bridge_sup_sup.erl

@@ -40,11 +40,11 @@ bridges() ->
                              <- supervisor:which_children(?MODULE)].
                              <- supervisor:which_children(?MODULE)].
 
 
 %% @doc Start a bridge
 %% @doc Start a bridge
--spec(start_bridge(atom(), binary()) -> {ok, pid()} | {error, any()}).
+-spec(start_bridge(atom(), binary()) -> {ok, pid()} | {error, term()}).
 start_bridge(Node, Topic) when is_atom(Node) andalso is_binary(Topic) ->
 start_bridge(Node, Topic) when is_atom(Node) andalso is_binary(Topic) ->
     start_bridge(Node, Topic, []).
     start_bridge(Node, Topic, []).
 
 
--spec(start_bridge(atom(), binary(), [emqttd_bridge:option()]) -> {ok, pid()} | {error, any()}).
+-spec(start_bridge(atom(), binary(), [emqttd_bridge:option()]) -> {ok, pid()} | {error, term()}).
 start_bridge(Node, _Topic, _Options) when Node =:= node() ->
 start_bridge(Node, _Topic, _Options) when Node =:= node() ->
     {error, bridge_to_self};
     {error, bridge_to_self};
 start_bridge(Node, Topic, Options) when is_atom(Node) andalso is_binary(Topic) ->
 start_bridge(Node, Topic, Options) when is_atom(Node) andalso is_binary(Topic) ->

+ 1 - 1
src/emqttd_broker.erl

@@ -61,7 +61,7 @@
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
 %% @doc Start emqttd broker
 %% @doc Start emqttd broker
--spec(start_link() -> {ok, pid()} | ignore | {error, any()}).
+-spec(start_link() -> {ok, pid()} | ignore | {error, term()}).
 start_link() ->
 start_link() ->
     gen_server:start_link({local, ?SERVER}, ?MODULE, [], []).
     gen_server:start_link({local, ?SERVER}, ?MODULE, [], []).
 
 

+ 0 - 9
src/emqttd_cli.erl

@@ -248,7 +248,6 @@ subscriptions(["show", ClientId]) ->
         Records -> [print(subscription, Subscription) || Subscription <- Records]
         Records -> [print(subscription, Subscription) || Subscription <- Records]
     end;
     end;
 
 
-
 subscriptions(["add", ClientId, Topic, QoS]) ->
 subscriptions(["add", ClientId, Topic, QoS]) ->
    Add = fun(IntQos) ->
    Add = fun(IntQos) ->
            case emqttd:subscribe(bin(Topic), bin(ClientId), [{qos, IntQos}]) of
            case emqttd:subscribe(bin(Topic), bin(ClientId), [{qos, IntQos}]) of
@@ -260,22 +259,14 @@ subscriptions(["add", ClientId, Topic, QoS]) ->
          end,
          end,
    if_valid_qos(QoS, Add);
    if_valid_qos(QoS, Add);
 
 
-
-
-subscriptions(["del", ClientId]) ->
-   Ok = emqttd:subscriber_down(bin(ClientId)),
-   ?PRINT("~p~n", [Ok]);
-
 subscriptions(["del", ClientId, Topic]) ->
 subscriptions(["del", ClientId, Topic]) ->
    Ok = emqttd:unsubscribe(bin(Topic), bin(ClientId)),
    Ok = emqttd:unsubscribe(bin(Topic), bin(ClientId)),
    ?PRINT("~p~n", [Ok]);
    ?PRINT("~p~n", [Ok]);
 
 
-
 subscriptions(_) ->
 subscriptions(_) ->
     ?USAGE([{"subscriptions list",                         "List all subscriptions"},
     ?USAGE([{"subscriptions list",                         "List all subscriptions"},
             {"subscriptions show <ClientId>",              "Show subscriptions of a client"},
             {"subscriptions show <ClientId>",              "Show subscriptions of a client"},
             {"subscriptions add <ClientId> <Topic> <QoS>", "Add a static subscription manually"},
             {"subscriptions add <ClientId> <Topic> <QoS>", "Add a static subscription manually"},
-            {"subscriptions del <ClientId>",               "Delete static subscriptions manually"},
             {"subscriptions del <ClientId> <Topic>",       "Delete a static subscription manually"}]).
             {"subscriptions del <ClientId> <Topic>",       "Delete a static subscription manually"}]).
 
 
 % if_could_print(Tab, Fun) ->
 % if_could_print(Tab, Fun) ->

+ 1 - 1
src/emqttd_cm.erl

@@ -47,7 +47,7 @@
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
 %% @doc Start Client Manager
 %% @doc Start Client Manager
--spec(start_link(atom(), pos_integer(), fun()) -> {ok, pid()} | ignore | {error, any()}).
+-spec(start_link(atom(), pos_integer(), fun()) -> {ok, pid()} | ignore | {error, term()}).
 start_link(Pool, Id, StatsFun) ->
 start_link(Pool, Id, StatsFun) ->
     gen_server2:start_link(?MODULE, [Pool, Id, StatsFun], []).
     gen_server2:start_link(?MODULE, [Pool, Id, StatsFun], []).
 
 

+ 2 - 2
src/emqttd_gen_mod.erl

@@ -24,9 +24,9 @@
 
 
 -ifdef(use_specs).
 -ifdef(use_specs).
 
 
--callback(load(Opts :: any()) -> ok | {error, any()}).
+-callback(load(Opts :: any()) -> ok | {error, term()}).
 
 
--callback(unload(State :: any()) -> any()).
+-callback(unload(State :: term()) -> term()).
 
 
 -else.
 -else.
 
 

+ 2 - 2
src/emqttd_keepalive.erl

@@ -29,7 +29,7 @@
 -export_type([keepalive/0]).
 -export_type([keepalive/0]).
 
 
 %% @doc Start a keepalive
 %% @doc Start a keepalive
--spec(start(fun(), integer(), any()) -> {ok, keepalive()} | {error, any()}).
+-spec(start(fun(), integer(), any()) -> {ok, keepalive()} | {error, term()}).
 start(_, 0, _) ->
 start(_, 0, _) ->
     {ok, #keepalive{}};
     {ok, #keepalive{}};
 start(StatFun, TimeoutSec, TimeoutMsg) ->
 start(StatFun, TimeoutSec, TimeoutMsg) ->
@@ -43,7 +43,7 @@ start(StatFun, TimeoutSec, TimeoutMsg) ->
     end.
     end.
 
 
 %% @doc Check keepalive, called when timeout.
 %% @doc Check keepalive, called when timeout.
--spec(check(keepalive()) -> {ok, keepalive()} | {error, any()}).
+-spec(check(keepalive()) -> {ok, keepalive()} | {error, term()}).
 check(KeepAlive = #keepalive{statfun = StatFun, statval = LastVal, repeat = Repeat}) ->
 check(KeepAlive = #keepalive{statfun = StatFun, statval = LastVal, repeat = Repeat}) ->
     case StatFun() of
     case StatFun() of
         {ok, NewVal} ->
         {ok, NewVal} ->

+ 2 - 2
src/emqttd_metrics.erl

@@ -96,8 +96,8 @@
 %% API
 %% API
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
-%% @doc Start metrics server
--spec(start_link() -> {ok, pid()} | ignore | {error, any()}).
+%% @doc Start the metrics server
+-spec(start_link() -> {ok, pid()} | ignore | {error, term()}).
 start_link() ->
 start_link() ->
     gen_server:start_link({local, ?SERVER}, ?MODULE, [], []).
     gen_server:start_link({local, ?SERVER}, ?MODULE, [], []).
 
 

+ 1 - 1
src/emqttd_mod_sup.erl

@@ -42,7 +42,7 @@ start_child(ChildSpec) when is_tuple(ChildSpec) ->
 start_child(Mod, Type) when is_atom(Mod) andalso is_atom(Type) ->
 start_child(Mod, Type) when is_atom(Mod) andalso is_atom(Type) ->
     supervisor:start_child(?MODULE, ?CHILD(Mod, Type)).
     supervisor:start_child(?MODULE, ?CHILD(Mod, Type)).
 
 
--spec(stop_child(any()) -> ok | {error, any()}).
+-spec(stop_child(any()) -> ok | {error, term()}).
 stop_child(ChildId) ->
 stop_child(ChildId) ->
     case supervisor:terminate_child(?MODULE, ChildId) of
     case supervisor:terminate_child(?MODULE, ChildId) of
         ok    -> supervisor:delete_child(?MODULE, ChildId);
         ok    -> supervisor:delete_child(?MODULE, ChildId);

+ 1 - 1
src/emqttd_parser.erl

@@ -39,7 +39,7 @@ initial_state(MaxSize) ->
 
 
 %% @doc Parse MQTT Packet
 %% @doc Parse MQTT Packet
 -spec(parse(binary(), {none, pos_integer()} | fun())
 -spec(parse(binary(), {none, pos_integer()} | fun())
-            -> {ok, mqtt_packet()} | {error, any()} | {more, fun()}).
+            -> {ok, mqtt_packet()} | {error, term()} | {more, fun()}).
 parse(<<>>, {none, MaxLen}) ->
 parse(<<>>, {none, MaxLen}) ->
     {more, fun(Bin) -> parse(Bin, {none, MaxLen}) end};
     {more, fun(Bin) -> parse(Bin, {none, MaxLen}) end};
 parse(<<Type:4, Dup:1, QoS:2, Retain:1, Rest/binary>>, {none, Limit}) ->
 parse(<<Type:4, Dup:1, QoS:2, Retain:1, Rest/binary>>, {none, Limit}) ->

+ 4 - 4
src/emqttd_plugins.erl

@@ -47,7 +47,7 @@ init_config(CfgFile) ->
                   end, AppsEnv).
                   end, AppsEnv).
 
 
 %% @doc Load all plugins when the broker started.
 %% @doc Load all plugins when the broker started.
--spec(load() -> list() | {error, any()}).
+-spec(load() -> list() | {error, term()}).
 load() ->
 load() ->
     case emqttd:env(plugins_loaded_file) of
     case emqttd:env(plugins_loaded_file) of
         {ok, File} ->
         {ok, File} ->
@@ -80,7 +80,7 @@ load_plugins(Names, Persistent) ->
     [load_plugin(find_plugin(Name, Plugins), Persistent) || Name <- NeedToLoad].
     [load_plugin(find_plugin(Name, Plugins), Persistent) || Name <- NeedToLoad].
 
 
 %% @doc Unload all plugins before broker stopped.
 %% @doc Unload all plugins before broker stopped.
--spec(unload() -> list() | {error, any()}).
+-spec(unload() -> list() | {error, term()}).
 unload() ->
 unload() ->
     case emqttd:env(plugins_loaded_file) of
     case emqttd:env(plugins_loaded_file) of
         {ok, File} ->
         {ok, File} ->
@@ -119,7 +119,7 @@ plugin(CfgFile) ->
     #mqtt_plugin{name = AppName, version = Ver, descr = Descr}.
     #mqtt_plugin{name = AppName, version = Ver, descr = Descr}.
 
 
 %% @doc Load a Plugin
 %% @doc Load a Plugin
--spec(load(atom()) -> ok | {error, any()}).
+-spec(load(atom()) -> ok | {error, term()}).
 load(PluginName) when is_atom(PluginName) ->
 load(PluginName) when is_atom(PluginName) ->
     case lists:member(PluginName, names(started_app)) of
     case lists:member(PluginName, names(started_app)) of
         true ->
         true ->
@@ -172,7 +172,7 @@ find_plugin(Name, Plugins) ->
     lists:keyfind(Name, 2, Plugins). 
     lists:keyfind(Name, 2, Plugins). 
 
 
 %% @doc UnLoad a Plugin
 %% @doc UnLoad a Plugin
--spec(unload(atom()) -> ok | {error, any()}).
+-spec(unload(atom()) -> ok | {error, term()}).
 unload(PluginName) when is_atom(PluginName) ->
 unload(PluginName) when is_atom(PluginName) ->
     case {lists:member(PluginName, names(started_app)), lists:member(PluginName, names(plugin))} of
     case {lists:member(PluginName, names(started_app)), lists:member(PluginName, names(plugin))} of
         {true, true} ->
         {true, true} ->

+ 2 - 2
src/emqttd_pool_sup.erl

@@ -36,12 +36,12 @@ spec(ChildId, Args) ->
     {ChildId, {?MODULE, start_link, Args},
     {ChildId, {?MODULE, start_link, Args},
         transient, infinity, supervisor, [?MODULE]}.
         transient, infinity, supervisor, [?MODULE]}.
 
 
--spec(start_link(atom() | tuple(), atom(), mfa()) -> {ok, pid()} | {error, any()}).
+-spec(start_link(atom() | tuple(), atom(), mfa()) -> {ok, pid()} | {error, term()}).
 start_link(Pool, Type, MFA) ->
 start_link(Pool, Type, MFA) ->
     Schedulers = erlang:system_info(schedulers),
     Schedulers = erlang:system_info(schedulers),
     start_link(Pool, Type, Schedulers, MFA).
     start_link(Pool, Type, Schedulers, MFA).
 
 
--spec(start_link(atom(), atom(), pos_integer(), mfa()) -> {ok, pid()} | {error, any()}).
+-spec(start_link(atom(), atom(), pos_integer(), mfa()) -> {ok, pid()} | {error, term()}).
 start_link(Pool, Type, Size, MFA) ->
 start_link(Pool, Type, Size, MFA) ->
     supervisor:start_link(?MODULE, [Pool, Type, Size, MFA]).
     supervisor:start_link(?MODULE, [Pool, Type, Size, MFA]).
 
 

+ 1 - 1
src/emqttd_pooler.erl

@@ -40,7 +40,7 @@ start_link() ->
 %% API
 %% API
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
--spec(start_link(atom(), pos_integer()) -> {ok, pid()} | ignore | {error, any()}).
+-spec(start_link(atom(), pos_integer()) -> {ok, pid()} | ignore | {error, term()}).
 start_link(Pool, Id) ->
 start_link(Pool, Id) ->
     gen_server:start_link({local, ?PROC_NAME(?MODULE, Id)}, ?MODULE, [Pool, Id], []).
     gen_server:start_link({local, ?PROC_NAME(?MODULE, Id)}, ?MODULE, [Pool, Id], []).
 
 

+ 1 - 1
src/emqttd_protocol.erl

@@ -124,7 +124,7 @@ session(#proto_state{session = Session}) ->
 %% CONNECT – Client requests a connection to a Server
 %% CONNECT – Client requests a connection to a Server
 
 
 %% A Client can only send the CONNECT Packet once over a Network Connection. 
 %% A Client can only send the CONNECT Packet once over a Network Connection. 
--spec(received(mqtt_packet(), proto_state()) -> {ok, proto_state()} | {error, any()}).
+-spec(received(mqtt_packet(), proto_state()) -> {ok, proto_state()} | {error, term()}).
 received(Packet = ?PACKET(?CONNECT),
 received(Packet = ?PACKET(?CONNECT),
          State = #proto_state{connected = false, stats_data = Stats}) ->
          State = #proto_state{connected = false, stats_data = Stats}) ->
     trace(recv, Packet, State), Stats1 = inc_stats(recv, ?CONNECT, Stats),
     trace(recv, Packet, State), Stats1 = inc_stats(recv, ?CONNECT, Stats),

+ 43 - 34
src/emqttd_pubsub.erl

@@ -46,7 +46,7 @@
 %% Start PubSub
 %% Start PubSub
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
--spec(start_link(atom(), pos_integer(), list()) -> {ok, pid()} | ignore | {error, any()}).
+-spec(start_link(atom(), pos_integer(), list()) -> {ok, pid()} | ignore | {error, term()}).
 start_link(Pool, Id, Env) ->
 start_link(Pool, Id, Env) ->
     gen_server2:start_link({local, ?PROC_NAME(?MODULE, Id)}, ?MODULE, [Pool, Id, Env], []).
     gen_server2:start_link({local, ?PROC_NAME(?MODULE, Id)}, ?MODULE, [Pool, Id, Env], []).
 
 
@@ -54,7 +54,7 @@ start_link(Pool, Id, Env) ->
 %% PubSub API
 %% PubSub API
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
-%% @doc Subscribe a Topic
+%% @doc Subscribe to a Topic
 -spec(subscribe(binary(), emqttd:subscriber(), [emqttd:suboption()]) -> ok).
 -spec(subscribe(binary(), emqttd:subscriber(), [emqttd:suboption()]) -> ok).
 subscribe(Topic, Subscriber, Options) ->
 subscribe(Topic, Subscriber, Options) ->
     call(pick(Topic), {subscribe, Topic, Subscriber, Options}).
     call(pick(Topic), {subscribe, Topic, Subscriber, Options}).
@@ -63,8 +63,8 @@ subscribe(Topic, Subscriber, Options) ->
 async_subscribe(Topic, Subscriber, Options) ->
 async_subscribe(Topic, Subscriber, Options) ->
     cast(pick(Topic), {subscribe, Topic, Subscriber, Options}).
     cast(pick(Topic), {subscribe, Topic, Subscriber, Options}).
 
 
-%% @doc Publish MQTT Message to Topic
--spec(publish(binary(), any()) -> {ok, mqtt_delivery()} | ignore).
+%% @doc Publish MQTT Message to Topic.
+-spec(publish(binary(), mqtt_message()) -> {ok, mqtt_delivery()} | ignore).
 publish(Topic, Msg) ->
 publish(Topic, Msg) ->
     route(lists:append(emqttd_router:match(Topic),
     route(lists:append(emqttd_router:match(Topic),
                        emqttd_router:match_local(Topic)), delivery(Msg)).
                        emqttd_router:match_local(Topic)), delivery(Msg)).
@@ -72,7 +72,7 @@ publish(Topic, Msg) ->
 route([], #mqtt_delivery{message = #mqtt_message{topic = Topic}}) ->
 route([], #mqtt_delivery{message = #mqtt_message{topic = Topic}}) ->
     dropped(Topic), ignore;
     dropped(Topic), ignore;
 
 
-%% Dispatch on the local node
+%% Dispatch on the local node.
 route([#mqtt_route{topic = To, node = Node}],
 route([#mqtt_route{topic = To, node = Node}],
       Delivery = #mqtt_delivery{flows = Flows}) when Node =:= node() ->
       Delivery = #mqtt_delivery{flows = Flows}) when Node =:= node() ->
     dispatch(To, Delivery#mqtt_delivery{flows = [{route, Node, To} | Flows]});
     dispatch(To, Delivery#mqtt_delivery{flows = [{route, Node, To} | Flows]});
@@ -82,8 +82,8 @@ route([#mqtt_route{topic = To, node = Node}], Delivery = #mqtt_delivery{flows =
     forward(Node, To, Delivery#mqtt_delivery{flows = [{route, Node, To}|Flows]});
     forward(Node, To, Delivery#mqtt_delivery{flows = [{route, Node, To}|Flows]});
 
 
 route(Routes, Delivery) ->
 route(Routes, Delivery) ->
-    {ok, lists:foldl(fun(Route, DelAcc) ->
-                    {ok, DelAcc1} = route([Route], DelAcc), DelAcc1
+    {ok, lists:foldl(fun(Route, Acc) ->
+                    {ok, Acc1} = route([Route], Acc), Acc1
             end, Delivery, Routes)}.
             end, Delivery, Routes)}.
 
 
 delivery(Msg) -> #mqtt_delivery{sender = self(), message = Msg, flows = []}.
 delivery(Msg) -> #mqtt_delivery{sender = self(), message = Msg, flows = []}.
@@ -92,7 +92,7 @@ delivery(Msg) -> #mqtt_delivery{sender = self(), message = Msg, flows = []}.
 forward(Node, To, Delivery) ->
 forward(Node, To, Delivery) ->
     rpc:cast(Node, ?PUBSUB, dispatch, [To, Delivery]), {ok, Delivery}.
     rpc:cast(Node, ?PUBSUB, dispatch, [To, Delivery]), {ok, Delivery}.
 
 
-%% @doc Dispatch Message to Subscribers
+%% @doc Dispatch Message to Subscribers.
 -spec(dispatch(binary(), mqtt_delivery()) -> mqtt_delivery()).
 -spec(dispatch(binary(), mqtt_delivery()) -> mqtt_delivery()).
 dispatch(Topic, Delivery = #mqtt_delivery{message = Msg, flows = Flows}) ->
 dispatch(Topic, Delivery = #mqtt_delivery{message = Msg, flows = Flows}) ->
     case subscribers(Topic) of
     case subscribers(Topic) of
@@ -107,16 +107,16 @@ dispatch(Topic, Delivery = #mqtt_delivery{message = Msg, flows = Flows}) ->
             {ok, Delivery#mqtt_delivery{flows = Flows1}}
             {ok, Delivery#mqtt_delivery{flows = Flows1}}
     end.
     end.
 
 
-dispatch(Pid, Topic, Msg) when is_pid(Pid) ->
-    Pid ! {dispatch, Topic, Msg};
-dispatch(SubId, Topic, Msg) when is_binary(SubId) ->
-    emqttd_sm:dispatch(SubId, Topic, Msg);
-dispatch({_Share, [Sub]}, Topic, Msg) ->
+%%TODO: Is SubPid aliving???
+dispatch(SubPid, Topic, Msg) when is_pid(SubPid) ->
+    SubPid ! {dispatch, Topic, Msg};
+dispatch({SubId, SubPid}, Topic, Msg) when is_binary(SubId), is_pid(SubPid) ->
+    SubPid ! {dispatch, Topic, Msg};
+dispatch({{share, _Share}, [Sub]}, Topic, Msg) ->
     dispatch(Sub, Topic, Msg);
     dispatch(Sub, Topic, Msg);
-dispatch({_Share, []}, _Topic, _Msg) ->
+dispatch({{share, _Share}, []}, _Topic, _Msg) ->
     ok;
     ok;
-%%TODO: round-robbin
-dispatch({_Share, Subs}, Topic, Msg) ->
+dispatch({{share, _Share}, Subs}, Topic, Msg) -> %% round-robbin?
     dispatch(lists:nth(rand:uniform(length(Subs)), Subs), Topic, Msg).
     dispatch(lists:nth(rand:uniform(length(Subs)), Subs), Topic, Msg).
 
 
 subscribers(Topic) ->
 subscribers(Topic) ->
@@ -126,8 +126,8 @@ group_by_share([]) -> [];
 
 
 group_by_share(Subscribers) ->
 group_by_share(Subscribers) ->
     {Subs1, Shares1} =
     {Subs1, Shares1} =
-    lists:foldl(fun({Share, Sub}, {Subs, Shares}) ->
-                    {Subs, dict:append(Share, Sub, Shares)};
+    lists:foldl(fun({share, Share, Sub}, {Subs, Shares}) ->
+                    {Subs, dict:append({share, Share}, Sub, Shares)};
                    (Sub, {Subs, Shares}) ->
                    (Sub, {Subs, Shares}) ->
                     {[Sub|Subs], Shares}
                     {[Sub|Subs], Shares}
                 end, {[], dict:new()}, Subscribers),
                 end, {[], dict:new()}, Subscribers),
@@ -155,8 +155,8 @@ call(PubSub, Req) when is_pid(PubSub) ->
 cast(PubSub, Msg) when is_pid(PubSub) ->
 cast(PubSub, Msg) when is_pid(PubSub) ->
     gen_server2:cast(PubSub, Msg).
     gen_server2:cast(PubSub, Msg).
 
 
-pick(Subscriber) ->
-    gproc_pool:pick_worker(pubsub, Subscriber).
+pick(Topic) ->
+    gproc_pool:pick_worker(pubsub, Topic).
 
 
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 %% gen_server Callbacks
 %% gen_server Callbacks
@@ -169,22 +169,22 @@ init([Pool, Id, Env]) ->
 
 
 handle_call({subscribe, Topic, Subscriber, Options}, _From, State) ->
 handle_call({subscribe, Topic, Subscriber, Options}, _From, State) ->
     add_subscriber(Topic, Subscriber, Options),
     add_subscriber(Topic, Subscriber, Options),
-    {reply, ok, setstats(State), hibernate};
+    reply(ok, setstats(State));
 
 
 handle_call({unsubscribe, Topic, Subscriber, Options}, _From, State) ->
 handle_call({unsubscribe, Topic, Subscriber, Options}, _From, State) ->
     del_subscriber(Topic, Subscriber, Options),
     del_subscriber(Topic, Subscriber, Options),
-    {reply, ok, setstats(State), hibernate};
+    reply(ok, setstats(State));
 
 
 handle_call(Req, _From, State) ->
 handle_call(Req, _From, State) ->
     ?UNEXPECTED_REQ(Req, State).
     ?UNEXPECTED_REQ(Req, State).
 
 
 handle_cast({subscribe, Topic, Subscriber, Options}, State) ->
 handle_cast({subscribe, Topic, Subscriber, Options}, State) ->
     add_subscriber(Topic, Subscriber, Options),
     add_subscriber(Topic, Subscriber, Options),
-    {noreply, setstats(State), hibernate};
+    noreply(setstats(State));
 
 
 handle_cast({unsubscribe, Topic, Subscriber, Options}, State) ->
 handle_cast({unsubscribe, Topic, Subscriber, Options}, State) ->
     del_subscriber(Topic, Subscriber, Options),
     del_subscriber(Topic, Subscriber, Options),
-    {noreply, setstats(State), hibernate};
+    noreply(setstats(State));
 
 
 handle_cast(Msg, State) ->
 handle_cast(Msg, State) ->
     ?UNEXPECTED_MSG(Msg, State).
     ?UNEXPECTED_MSG(Msg, State).
@@ -205,39 +205,48 @@ code_change(_OldVsn, State, _Extra) ->
 add_subscriber(Topic, Subscriber, Options) ->
 add_subscriber(Topic, Subscriber, Options) ->
     Share = proplists:get_value(share, Options),
     Share = proplists:get_value(share, Options),
     case ?is_local(Options) of
     case ?is_local(Options) of
-        false -> add_subscriber_(Share, Topic, Subscriber);
-        true  -> add_local_subscriber_(Share, Topic, Subscriber)
+        false -> add_global_subscriber(Share, Topic, Subscriber);
+        true  -> add_local_subscriber(Share, Topic, Subscriber)
     end.
     end.
 
 
-add_subscriber_(Share, Topic, Subscriber) ->
-    (not ets:member(mqtt_subscriber, Topic)) andalso emqttd_router:add_route(Topic),
+add_global_subscriber(Share, Topic, Subscriber) ->
+    case ets:member(mqtt_subscriber, Topic) and emqttd_router:has_route(Topic) of
+        true  -> ok;
+        false -> emqttd_router:add_route(Topic)
+    end,
     ets:insert(mqtt_subscriber, {Topic, shared(Share, Subscriber)}).
     ets:insert(mqtt_subscriber, {Topic, shared(Share, Subscriber)}).
 
 
-add_local_subscriber_(Share, Topic, Subscriber) ->
+add_local_subscriber(Share, Topic, Subscriber) ->
     (not ets:member(mqtt_subscriber, {local, Topic})) andalso emqttd_router:add_local_route(Topic),
     (not ets:member(mqtt_subscriber, {local, Topic})) andalso emqttd_router:add_local_route(Topic),
     ets:insert(mqtt_subscriber, {{local, Topic}, shared(Share, Subscriber)}).
     ets:insert(mqtt_subscriber, {{local, Topic}, shared(Share, Subscriber)}).
 
 
 del_subscriber(Topic, Subscriber, Options) ->
 del_subscriber(Topic, Subscriber, Options) ->
     Share = proplists:get_value(share, Options),
     Share = proplists:get_value(share, Options),
     case ?is_local(Options) of
     case ?is_local(Options) of
-        false -> del_subscriber_(Share, Topic, Subscriber);
-        true  -> del_local_subscriber_(Share, Topic, Subscriber)
+        false -> del_global_subscriber(Share, Topic, Subscriber);
+        true  -> del_local_subscriber(Share, Topic, Subscriber)
     end.
     end.
 
 
-del_subscriber_(Share, Topic, Subscriber) ->
+del_global_subscriber(Share, Topic, Subscriber) ->
     ets:delete_object(mqtt_subscriber, {Topic, shared(Share, Subscriber)}),
     ets:delete_object(mqtt_subscriber, {Topic, shared(Share, Subscriber)}),
     (not ets:member(mqtt_subscriber, Topic)) andalso emqttd_router:del_route(Topic).
     (not ets:member(mqtt_subscriber, Topic)) andalso emqttd_router:del_route(Topic).
 
 
-del_local_subscriber_(Share, Topic, Subscriber) ->
+del_local_subscriber(Share, Topic, Subscriber) ->
     ets:delete_object(mqtt_subscriber, {{local, Topic}, shared(Share, Subscriber)}),
     ets:delete_object(mqtt_subscriber, {{local, Topic}, shared(Share, Subscriber)}),
     (not ets:member(mqtt_subscriber, {local, Topic})) andalso emqttd_router:del_local_route(Topic).
     (not ets:member(mqtt_subscriber, {local, Topic})) andalso emqttd_router:del_local_route(Topic).
 
 
 shared(undefined, Subscriber) ->
 shared(undefined, Subscriber) ->
     Subscriber;
     Subscriber;
 shared(Share, Subscriber) ->
 shared(Share, Subscriber) ->
-    {Share, Subscriber}.
+    {share, Share, Subscriber}.
 
 
 setstats(State) ->
 setstats(State) ->
     emqttd_stats:setstats('subscribers/count', 'subscribers/max', ets:info(mqtt_subscriber, size)),
     emqttd_stats:setstats('subscribers/count', 'subscribers/max', ets:info(mqtt_subscriber, size)),
     State.
     State.
 
 
+reply(Reply, State) ->
+    {reply, Reply, State, hibernate}.
+
+noreply(State) ->
+    {noreply, State, hibernate}.
+

+ 100 - 97
src/emqttd_router.erl

@@ -28,15 +28,17 @@
 -boot_mnesia({mnesia, [boot]}).
 -boot_mnesia({mnesia, [boot]}).
 -copy_mnesia({mnesia, [copy]}).
 -copy_mnesia({mnesia, [copy]}).
 
 
-%% Start/Stop
--export([start_link/0, topics/0, local_topics/0, stop/0]).
+-export([start_link/0, topics/0, local_topics/0]).
+
+%% For eunit tests
+-export([start/0, stop/0]).
 
 
 %% Route APIs
 %% Route APIs
--export([add_route/1, add_route/2, add_routes/1, match/1, print/1,
-         del_route/1, del_route/2, del_routes/1, has_route/1]).
+-export([add_route/1, del_route/1, match/1, print/1, has_route/1]).
 
 
 %% Local Route API
 %% Local Route API
--export([add_local_route/1, del_local_route/1, match_local/1]).
+-export([get_local_routes/0, add_local_route/1, match_local/1,
+         del_local_route/1, clean_local_routes/0]).
 
 
 %% gen_server Function Exports
 %% gen_server Function Exports
 -export([init/1, handle_call/3, handle_cast/2, handle_info/2,
 -export([init/1, handle_call/3, handle_cast/2, handle_info/2,
@@ -55,10 +57,6 @@
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
 mnesia(boot) ->
 mnesia(boot) ->
-    ok = ekka_mnesia:create_table(mqtt_topic, [
-                {ram_copies, [node()]},
-                {record_name, mqtt_topic},
-                {attributes, record_info(fields, mqtt_topic)}]),
     ok = ekka_mnesia:create_table(mqtt_route, [
     ok = ekka_mnesia:create_table(mqtt_route, [
                 {type, bag},
                 {type, bag},
                 {ram_copies, [node()]},
                 {ram_copies, [node()]},
@@ -66,7 +64,6 @@ mnesia(boot) ->
                 {attributes, record_info(fields, mqtt_route)}]);
                 {attributes, record_info(fields, mqtt_route)}]);
 
 
 mnesia(copy) ->
 mnesia(copy) ->
-    ok = ekka_mnesia:copy_table(mqtt_topic),
     ok = ekka_mnesia:copy_table(mqtt_route, ram_copies).
     ok = ekka_mnesia:copy_table(mqtt_route, ram_copies).
 
 
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
@@ -77,19 +74,26 @@ start_link() ->
     gen_server:start_link({local, ?ROUTER}, ?MODULE, [], []).
     gen_server:start_link({local, ?ROUTER}, ?MODULE, [], []).
 
 
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
-%% API
+%% Topics
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
+-spec(topics() -> list(binary())).
 topics() ->
 topics() ->
     mnesia:dirty_all_keys(mqtt_route).
     mnesia:dirty_all_keys(mqtt_route).
 
 
+-spec(local_topics() -> list(binary())).
 local_topics() ->
 local_topics() ->
     ets:select(mqtt_local_route, [{{'$1', '_'}, [], ['$1']}]).
     ets:select(mqtt_local_route, [{{'$1', '_'}, [], ['$1']}]).
 
 
+%%--------------------------------------------------------------------
+%% Match API
+%%--------------------------------------------------------------------
+
 %% @doc Match Routes.
 %% @doc Match Routes.
 -spec(match(Topic:: binary()) -> [mqtt_route()]).
 -spec(match(Topic:: binary()) -> [mqtt_route()]).
 match(Topic) when is_binary(Topic) ->
 match(Topic) when is_binary(Topic) ->
-    Matched = mnesia:async_dirty(fun emqttd_trie:match/1, [Topic]),
+    %% Optimize: ets???
+    Matched = mnesia:ets(fun emqttd_trie:match/1, [Topic]),
     %% Optimize: route table will be replicated to all nodes.
     %% Optimize: route table will be replicated to all nodes.
     lists:append([ets:lookup(mqtt_route, To) || To <- [Topic | Matched]]).
     lists:append([ets:lookup(mqtt_route, To) || To <- [Topic | Matched]]).
 
 
@@ -99,93 +103,68 @@ print(Topic) ->
     [io:format("~s -> ~s~n", [To, Node]) ||
     [io:format("~s -> ~s~n", [To, Node]) ||
         #mqtt_route{topic = To, node = Node} <- match(Topic)].
         #mqtt_route{topic = To, node = Node} <- match(Topic)].
 
 
-%% @doc Add Route
--spec(add_route(binary() | mqtt_route()) -> ok | {error, Reason :: any()}).
+%%--------------------------------------------------------------------
+%% Route Management API
+%%--------------------------------------------------------------------
+
+%% @doc Add Route.
+-spec(add_route(binary() | mqtt_route()) -> ok | {error, Reason :: term()}).
 add_route(Topic) when is_binary(Topic) ->
 add_route(Topic) when is_binary(Topic) ->
     add_route(#mqtt_route{topic = Topic, node = node()});
     add_route(#mqtt_route{topic = Topic, node = node()});
-add_route(Route) when is_record(Route, mqtt_route) ->
-    add_routes([Route]).
-
--spec(add_route(Topic :: binary(), Node :: node()) -> ok | {error, Reason :: any()}).
-add_route(Topic, Node) when is_binary(Topic), is_atom(Node) ->
-    add_route(#mqtt_route{topic = Topic, node = Node}).
-
-%% @doc Add Routes
--spec(add_routes([mqtt_route()]) -> ok | {error, Reason :: any()}).
-add_routes(Routes) ->
-    AddFun = fun() -> [add_route_(Route) || Route <- Routes] end,
-    case mnesia:is_transaction() of
-        true  -> AddFun();
-        false -> trans(AddFun)
+add_route(Route = #mqtt_route{topic = Topic}) ->
+    case emqttd_topic:wildcard(Topic) of
+        true  -> case mnesia:is_transaction() of
+                     true  -> add_trie_route(Route);
+                     false -> trans(fun add_trie_route/1, [Route])
+                 end;
+        false -> add_direct_route(Route)
     end.
     end.
 
 
-%% @private
-add_route_(Route = #mqtt_route{topic = Topic}) ->
+add_direct_route(Route) ->
+    mnesia:async_dirty(fun mnesia:write/1, [Route]).
+
+add_trie_route(Route = #mqtt_route{topic = Topic}) ->
     case mnesia:wread({mqtt_route, Topic}) of
     case mnesia:wread({mqtt_route, Topic}) of
-        [] ->
-            case emqttd_topic:wildcard(Topic) of
-                true  -> emqttd_trie:insert(Topic);
-                false -> ok
-            end,
-            mnesia:write(Route),
-            mnesia:write(#mqtt_topic{topic = Topic});
-        Records ->
-            case lists:member(Route, Records) of
-                true  -> ok;
-                false -> mnesia:write(Route)
-            end
-    end.
+        [] -> emqttd_trie:insert(Topic);
+        _  -> ok
+    end,
+    mnesia:write(Route).
 
 
 %% @doc Delete Route
 %% @doc Delete Route
--spec(del_route(binary() | mqtt_route()) -> ok | {error, Reason :: any()}).
+-spec(del_route(binary() | mqtt_route()) -> ok | {error, Reason :: term()}).
 del_route(Topic) when is_binary(Topic) ->
 del_route(Topic) when is_binary(Topic) ->
     del_route(#mqtt_route{topic = Topic, node = node()});
     del_route(#mqtt_route{topic = Topic, node = node()});
-del_route(Route) when is_record(Route, mqtt_route) ->
-    del_routes([Route]).
-
--spec(del_route(Topic :: binary(), Node :: node()) -> ok | {error, Reason :: any()}).
-del_route(Topic, Node) when is_binary(Topic), is_atom(Node) ->
-    del_route(#mqtt_route{topic = Topic, node = Node}).
-
-%% @doc Delete Routes
--spec(del_routes([mqtt_route()]) -> ok | {error, any()}).
-del_routes(Routes) ->
-    DelFun = fun() -> [del_route_(Route) || Route <- Routes] end,
-    case mnesia:is_transaction() of
-        true  -> DelFun();
-        false -> trans(DelFun)
+del_route(Route = #mqtt_route{topic = Topic}) ->
+    case emqttd_topic:wildcard(Topic) of
+        true  -> case mnesia:is_transaction() of
+                     true  -> del_trie_route(Route);
+                     false -> trans(fun del_trie_route/1, [Route])
+                 end;
+        false -> del_direct_route(Route)
     end.
     end.
 
 
-del_route_(Route = #mqtt_route{topic = Topic}) ->
+del_direct_route(Route) ->
+    mnesia:async_dirty(fun mnesia:delete_object/1, [Route]).
+
+del_trie_route(Route = #mqtt_route{topic = Topic}) ->
     case mnesia:wread({mqtt_route, Topic}) of
     case mnesia:wread({mqtt_route, Topic}) of
-        [] ->
-            ok;
-        [Route] ->
-            %% Remove route and trie
-            mnesia:delete_object(Route),
-            case emqttd_topic:wildcard(Topic) of
-                true  -> emqttd_trie:delete(Topic);
-                false -> ok
-            end,
-            mnesia:delete({mqtt_topic, Topic});
-        _More ->
-            %% Remove route only
-            mnesia:delete_object(Route)
+        [Route] -> %% Remove route and trie
+                   mnesia:delete_object(Route),
+                   emqttd_trie:delete(Topic);
+        [_|_]   -> %% Remove route only
+                   mnesia:delete_object(Route);
+        []      -> ok
     end.
     end.
 
 
-%% @doc Has Route?
+%% @doc Has route?
 -spec(has_route(binary()) -> boolean()).
 -spec(has_route(binary()) -> boolean()).
-has_route(Topic) ->
-    Routes = case mnesia:is_transaction() of
-                 true  -> mnesia:read(mqtt_route, Topic);
-                 false -> mnesia:dirty_read(mqtt_route, Topic)
-             end,
-    length(Routes) > 0.
+has_route(Topic) when is_binary(Topic) ->
+    ets:member(mqtt_route, Topic).
 
 
 %% @private
 %% @private
--spec(trans(function()) -> ok | {error, any()}).
-trans(Fun) ->
-    case mnesia:transaction(Fun) of
+-spec(trans(function(), list(any())) -> ok | {error, term()}).
+trans(Fun, Args) ->
+    case mnesia:transaction(Fun, Args) of
         {atomic, _}      -> ok;
         {atomic, _}      -> ok;
         {aborted, Error} -> {error, Error}
         {aborted, Error} -> {error, Error}
     end.
     end.
@@ -194,24 +173,44 @@ trans(Fun) ->
 %% Local Route API
 %% Local Route API
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
+-spec(get_local_routes() -> list({binary(), node()})).
+get_local_routes() ->
+    ets:tab2list(mqtt_local_route).
+
 -spec(add_local_route(binary()) -> ok).
 -spec(add_local_route(binary()) -> ok).
 add_local_route(Topic) ->
 add_local_route(Topic) ->
-    gen_server:cast(?ROUTER, {add_local_route, Topic}).
+    gen_server:call(?ROUTER, {add_local_route, Topic}).
     
     
 -spec(del_local_route(binary()) -> ok).
 -spec(del_local_route(binary()) -> ok).
 del_local_route(Topic) ->
 del_local_route(Topic) ->
-    gen_server:cast(?ROUTER, {del_local_route, Topic}).
+    gen_server:call(?ROUTER, {del_local_route, Topic}).
     
     
 -spec(match_local(binary()) -> [mqtt_route()]).
 -spec(match_local(binary()) -> [mqtt_route()]).
 match_local(Name) ->
 match_local(Name) ->
-    [#mqtt_route{topic = {local, Filter}, node = Node}
-        || {Filter, Node} <- ets:tab2list(mqtt_local_route),
-           emqttd_topic:match(Name, Filter)].
+    case ets:info(mqtt_local_route, size) of
+        0 -> [];
+        _ -> ets:foldl(
+               fun({Filter, Node}, Matched) ->
+                   case emqttd_topic:match(Name, Filter) of
+                       true  -> [#mqtt_route{topic = {local, Filter}, node = Node} | Matched];
+                       false -> Matched
+                   end
+               end, [], mqtt_local_route)
+    end.
+
+-spec(clean_local_routes() -> ok).
+clean_local_routes() ->
+    gen_server:call(?ROUTER, clean_local_routes).
 
 
 dump() ->
 dump() ->
     [{route, ets:tab2list(mqtt_route)}, {local_route, ets:tab2list(mqtt_local_route)}].
     [{route, ets:tab2list(mqtt_route)}, {local_route, ets:tab2list(mqtt_local_route)}].
 
 
-stop() -> gen_server:call(?ROUTER, stop).
+%% For unit test.
+start() ->
+    gen_server:start({local, ?ROUTER}, ?MODULE, [], []).
+
+stop() ->
+    gen_server:call(?ROUTER, stop).
 
 
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 %% gen_server Callbacks
 %% gen_server Callbacks
@@ -223,21 +222,25 @@ init([]) ->
     {ok, TRef} = timer:send_interval(timer:seconds(1), stats),
     {ok, TRef} = timer:send_interval(timer:seconds(1), stats),
     {ok, #state{stats_timer = TRef}}.
     {ok, #state{stats_timer = TRef}}.
 
 
+handle_call({add_local_route, Topic}, _From, State) ->
+    %% why node()...?
+    ets:insert(mqtt_local_route, {Topic, node()}),
+    {reply, ok, State};
+
+handle_call({del_local_route, Topic}, _From, State) ->
+    ets:delete(mqtt_local_route, Topic),
+    {reply, ok, State};
+
+handle_call(clean_local_routes, _From, State) ->
+    ets:delete_all_objects(mqtt_local_route),
+    {reply, ok, State};
+
 handle_call(stop, _From, State) ->
 handle_call(stop, _From, State) ->
     {stop, normal, ok, State};
     {stop, normal, ok, State};
 
 
 handle_call(_Req, _From, State) ->
 handle_call(_Req, _From, State) ->
     {reply, ignore, State}.
     {reply, ignore, State}.
 
 
-handle_cast({add_local_route, Topic}, State) ->
-    %% why node()...?
-    ets:insert(mqtt_local_route, {Topic, node()}),
-    {noreply, State};
-
-handle_cast({del_local_route, Topic}, State) ->
-    ets:delete(mqtt_local_route, Topic),
-    {noreply, State};
-
 handle_cast(_Msg, State) ->
 handle_cast(_Msg, State) ->
     {noreply, State}.
     {noreply, State}.
 
 

+ 112 - 90
src/emqttd_server.erl

@@ -37,8 +37,7 @@
          async_unsubscribe/1, async_unsubscribe/2]).
          async_unsubscribe/1, async_unsubscribe/2]).
 
 
 %% Management API.
 %% Management API.
--export([setqos/3, subscriptions/1, subscribers/1, is_subscribed/2,
-         subscriber_down/1]).
+-export([setqos/3, subscriptions/1, subscribers/1, subscribed/2]).
 
 
 %% Debug API
 %% Debug API
 -export([dump/0]).
 -export([dump/0]).
@@ -47,10 +46,10 @@
 -export([init/1, handle_call/3, handle_cast/2, handle_info/2,
 -export([init/1, handle_call/3, handle_cast/2, handle_info/2,
          terminate/2, code_change/3]).
          terminate/2, code_change/3]).
 
 
--record(state, {pool, id, env, submon :: emqttd_pmon:pmon()}).
+-record(state, {pool, id, env, subids :: map(), submon :: emqttd_pmon:pmon()}).
 
 
-%% @doc Start server
--spec(start_link(atom(), pos_integer(), list()) -> {ok, pid()} | ignore | {error, any()}).
+%% @doc Start the server
+-spec(start_link(atom(), pos_integer(), list()) -> {ok, pid()} | ignore | {error, term()}).
 start_link(Pool, Id, Env) ->
 start_link(Pool, Id, Env) ->
     gen_server2:start_link({local, ?PROC_NAME(?MODULE, Id)}, ?MODULE, [Pool, Id, Env], []).
     gen_server2:start_link({local, ?PROC_NAME(?MODULE, Id)}, ?MODULE, [Pool, Id, Env], []).
 
 
@@ -58,21 +57,21 @@ start_link(Pool, Id, Env) ->
 %% PubSub API
 %% PubSub API
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
-%% @doc Subscribe a Topic
--spec(subscribe(binary()) -> ok | emqttd:pubsub_error()).
+%% @doc Subscribe to a Topic.
+-spec(subscribe(binary()) -> ok | {error, term()}).
 subscribe(Topic) when is_binary(Topic) ->
 subscribe(Topic) when is_binary(Topic) ->
     subscribe(Topic, self()).
     subscribe(Topic, self()).
 
 
--spec(subscribe(binary(), emqttd:subscriber()) -> ok | emqttd:pubsub_error()).
+-spec(subscribe(binary(), emqttd:subscriber()) -> ok | {error, term()}).
 subscribe(Topic, Subscriber) when is_binary(Topic) ->
 subscribe(Topic, Subscriber) when is_binary(Topic) ->
     subscribe(Topic, Subscriber, []).
     subscribe(Topic, Subscriber, []).
 
 
 -spec(subscribe(binary(), emqttd:subscriber(), [emqttd:suboption()]) ->
 -spec(subscribe(binary(), emqttd:subscriber(), [emqttd:suboption()]) ->
-      ok | emqttd:pubsub_error()).
+      ok | {error, term()}).
 subscribe(Topic, Subscriber, Options) when is_binary(Topic) ->
 subscribe(Topic, Subscriber, Options) when is_binary(Topic) ->
-    call(pick(Subscriber), {subscribe, Topic, Subscriber, Options}).
+    call(pick(Subscriber), {subscribe, Topic, with_subpid(Subscriber), Options}).
 
 
-%% @doc Subscribe a Topic Asynchronously
+%% @doc Subscribe to a Topic asynchronously.
 -spec(async_subscribe(binary()) -> ok).
 -spec(async_subscribe(binary()) -> ok).
 async_subscribe(Topic) when is_binary(Topic) ->
 async_subscribe(Topic) when is_binary(Topic) ->
     async_subscribe(Topic, self()).
     async_subscribe(Topic, self()).
@@ -83,7 +82,7 @@ async_subscribe(Topic, Subscriber) when is_binary(Topic) ->
 
 
 -spec(async_subscribe(binary(), emqttd:subscriber(), [emqttd:suboption()]) -> ok).
 -spec(async_subscribe(binary(), emqttd:subscriber(), [emqttd:suboption()]) -> ok).
 async_subscribe(Topic, Subscriber, Options) when is_binary(Topic) ->
 async_subscribe(Topic, Subscriber, Options) when is_binary(Topic) ->
-    cast(pick(Subscriber), {subscribe, Topic, Subscriber, Options}).
+    cast(pick(Subscriber), {subscribe, Topic, with_subpid(Subscriber), Options}).
 
 
 %% @doc Publish message to Topic.
 %% @doc Publish message to Topic.
 -spec(publish(mqtt_message()) -> {ok, mqtt_delivery()} | ignore).
 -spec(publish(mqtt_message()) -> {ok, mqtt_delivery()} | ignore).
@@ -109,14 +108,14 @@ trace(publish, From, #mqtt_message{topic = Topic, payload = Payload}) ->
                 "~s PUBLISH to ~s: ~p", [From, Topic, Payload]).
                 "~s PUBLISH to ~s: ~p", [From, Topic, Payload]).
 
 
 %% @doc Unsubscribe
 %% @doc Unsubscribe
--spec(unsubscribe(binary()) -> ok | emqttd:pubsub_error()).
+-spec(unsubscribe(binary()) -> ok | {error, term()}).
 unsubscribe(Topic) when is_binary(Topic) ->
 unsubscribe(Topic) when is_binary(Topic) ->
     unsubscribe(Topic, self()).
     unsubscribe(Topic, self()).
 
 
 %% @doc Unsubscribe
 %% @doc Unsubscribe
--spec(unsubscribe(binary(), emqttd:subscriber()) -> ok | emqttd:pubsub_error()).
+-spec(unsubscribe(binary(), emqttd:subscriber()) -> ok | {error, term()}).
 unsubscribe(Topic, Subscriber) when is_binary(Topic) ->
 unsubscribe(Topic, Subscriber) when is_binary(Topic) ->
-    call(pick(Subscriber), {unsubscribe, Topic, Subscriber}).
+    call(pick(Subscriber), {unsubscribe, Topic, with_subpid(Subscriber)}).
 
 
 %% @doc Async Unsubscribe
 %% @doc Async Unsubscribe
 -spec(async_unsubscribe(binary()) -> ok).
 -spec(async_unsubscribe(binary()) -> ok).
@@ -125,32 +124,47 @@ async_unsubscribe(Topic) when is_binary(Topic) ->
 
 
 -spec(async_unsubscribe(binary(), emqttd:subscriber()) -> ok).
 -spec(async_unsubscribe(binary(), emqttd:subscriber()) -> ok).
 async_unsubscribe(Topic, Subscriber) when is_binary(Topic) ->
 async_unsubscribe(Topic, Subscriber) when is_binary(Topic) ->
-    cast(pick(Subscriber), {unsubscribe, Topic, Subscriber}).
+    cast(pick(Subscriber), {unsubscribe, Topic, with_subpid(Subscriber)}).
 
 
+-spec(setqos(binary(), emqttd:subscriber(), mqtt_qos()) -> ok).
 setqos(Topic, Subscriber, Qos) when is_binary(Topic) ->
 setqos(Topic, Subscriber, Qos) when is_binary(Topic) ->
-    call(pick(Subscriber), {setqos, Topic, Subscriber, Qos}).
-
--spec(subscriptions(emqttd:subscriber()) -> [{binary(), binary(), list(emqttd:suboption())}]).
-subscriptions(Subscriber) ->
-    lists:map(fun({_, {_Share, Topic}}) ->
-                subscription(Topic, Subscriber);
-                 ({_, Topic}) ->
-                subscription(Topic, Subscriber)
-        end, ets:lookup(mqtt_subscription, Subscriber)).
-
-subscription(Topic, Subscriber) ->
-    {Topic, Subscriber, ets:lookup_element(mqtt_subproperty, {Topic, Subscriber}, 2)}.
-
-subscribers(Topic) ->
+    call(pick(Subscriber), {setqos, Topic, with_subpid(Subscriber), Qos}).
+
+with_subpid(SubPid) when is_pid(SubPid) ->
+    SubPid;
+with_subpid(SubId) when is_binary(SubId) ->
+    {SubId, self()};
+with_subpid({SubId, SubPid}) when is_binary(SubId), is_pid(SubPid) ->
+    {SubId, SubPid}.
+
+-spec(subscriptions(emqttd:subscriber()) -> [{emqttd:subscriber(), binary(), list(emqttd:suboption())}]).
+subscriptions(SubPid) when is_pid(SubPid) ->
+    with_subproperty(ets:lookup(mqtt_subscription, SubPid));
+
+subscriptions(SubId) when is_binary(SubId) ->
+    with_subproperty(ets:match_object(mqtt_subscription, {{SubId, '_'}, '_'}));
+
+subscriptions({SubId, SubPid}) when is_binary(SubId), is_pid(SubPid) ->
+    with_subproperty(ets:lookup(mqtt_subscription, {SubId, SubPid})).
+
+with_subproperty({Subscriber, {share, _Share, Topic}}) ->
+    with_subproperty({Subscriber, Topic});
+with_subproperty({Subscriber, Topic}) ->
+    {Subscriber, Topic, ets:lookup_element(mqtt_subproperty, {Topic, Subscriber}, 2)};
+with_subproperty(Subscriptions) when is_list(Subscriptions) ->
+    [with_subproperty(Subscription) || Subscription <- Subscriptions].
+
+-spec(subscribers(binary()) -> list(emqttd:subscriber())).
+subscribers(Topic) when is_binary(Topic) ->
     emqttd_pubsub:subscribers(Topic).
     emqttd_pubsub:subscribers(Topic).
 
 
--spec(is_subscribed(binary(), emqttd:subscriber()) -> boolean()).
-is_subscribed(Topic, Subscriber) when is_binary(Topic) ->
-    ets:member(mqtt_subproperty, {Topic, Subscriber}).
-
--spec(subscriber_down(emqttd:subscriber()) -> ok).
-subscriber_down(Subscriber) ->
-    cast(pick(Subscriber), {subscriber_down, Subscriber}).
+-spec(subscribed(binary(), emqttd:subscriber()) -> boolean()).
+subscribed(Topic, SubPid) when is_binary(Topic), is_pid(SubPid) ->
+    ets:member(mqtt_subproperty, {Topic, SubPid});
+subscribed(Topic, SubId) when is_binary(Topic), is_binary(SubId) ->
+    length(ets:match_object(mqtt_subproperty, {{Topic, {SubId, '_'}}, '_'}, 1)) == 1;
+subscribed(Topic, {SubId, SubPid}) when is_binary(Topic), is_binary(SubId), is_pid(SubPid) ->
+    ets:member(mqtt_subproperty, {Topic, {SubId, SubPid}}).
 
 
 call(Server, Req) ->
 call(Server, Req) ->
     gen_server2:call(Server, Req, infinity).
     gen_server2:call(Server, Req, infinity).
@@ -158,8 +172,12 @@ call(Server, Req) ->
 cast(Server, Msg) when is_pid(Server) ->
 cast(Server, Msg) when is_pid(Server) ->
     gen_server2:cast(Server, Msg).
     gen_server2:cast(Server, Msg).
 
 
-pick(Subscriber) ->
-    gproc_pool:pick_worker(server, Subscriber).
+pick(SubPid) when is_pid(SubPid) ->
+    gproc_pool:pick_worker(server, SubPid);
+pick(SubId) when is_binary(SubId) ->
+    gproc_pool:pick_worker(server, SubId);
+pick({SubId, SubPid}) when is_binary(SubId), is_pid(SubPid) ->
+    pick(SubId).
 
 
 dump() ->
 dump() ->
     [{Tab, ets:tab2list(Tab)} || Tab <- [mqtt_subproperty, mqtt_subscription, mqtt_subscriber]].
     [{Tab, ets:tab2list(Tab)} || Tab <- [mqtt_subproperty, mqtt_subscription, mqtt_subscriber]].
@@ -170,18 +188,20 @@ dump() ->
 
 
 init([Pool, Id, Env]) ->
 init([Pool, Id, Env]) ->
     ?GPROC_POOL(join, Pool, Id),
     ?GPROC_POOL(join, Pool, Id),
-    {ok, #state{pool = Pool, id = Id, env = Env, submon = emqttd_pmon:new()}}.
+    State = #state{pool = Pool, id = Id, env = Env,
+                   subids = #{}, submon = emqttd_pmon:new()},
+    {ok, State, hibernate, {backoff, 2000, 2000, 20000}}.
 
 
 handle_call({subscribe, Topic, Subscriber, Options}, _From, State) ->
 handle_call({subscribe, Topic, Subscriber, Options}, _From, State) ->
-    case do_subscribe_(Topic, Subscriber, Options, State) of
-        {ok, NewState} -> {reply, ok, setstats(NewState)};
-        {error, Error} -> {reply, {error, Error}, State}
+    case do_subscribe(Topic, Subscriber, Options, State) of
+        {ok, NewState} -> reply(ok, setstats(NewState));
+        {error, Error} -> reply({error, Error}, State)
     end;
     end;
 
 
 handle_call({unsubscribe, Topic, Subscriber}, _From, State) ->
 handle_call({unsubscribe, Topic, Subscriber}, _From, State) ->
-    case do_unsubscribe_(Topic, Subscriber, State) of
-        {ok, NewState} -> {reply, ok, setstats(NewState), hibernate};
-        {error, Error} -> {reply, {error, Error}, State}
+    case do_unsubscribe(Topic, Subscriber, State) of
+        {ok, NewState} -> reply(ok, setstats(NewState));
+        {error, Error} -> reply({error, Error}, State)
     end;
     end;
 
 
 handle_call({setqos, Topic, Subscriber, Qos}, _From, State) ->
 handle_call({setqos, Topic, Subscriber, Qos}, _From, State) ->
@@ -190,36 +210,37 @@ handle_call({setqos, Topic, Subscriber, Qos}, _From, State) ->
         [{_, Opts}] ->
         [{_, Opts}] ->
             Opts1 = lists:ukeymerge(1, [{qos, Qos}], Opts),
             Opts1 = lists:ukeymerge(1, [{qos, Qos}], Opts),
             ets:insert(mqtt_subproperty, {Key, Opts1}),
             ets:insert(mqtt_subproperty, {Key, Opts1}),
-            {reply, ok, State};
+            reply(ok, State);
         [] ->
         [] ->
-            {reply, {error, {subscription_not_found, Topic}}, State}
+            reply({error, {subscription_not_found, Topic}}, State)
     end;
     end;
 
 
 handle_call(Req, _From, State) ->
 handle_call(Req, _From, State) ->
     ?UNEXPECTED_REQ(Req, State).
     ?UNEXPECTED_REQ(Req, State).
 
 
 handle_cast({subscribe, Topic, Subscriber, Options}, State) ->
 handle_cast({subscribe, Topic, Subscriber, Options}, State) ->
-    case do_subscribe_(Topic, Subscriber, Options, State) of
-        {ok, NewState}  -> {noreply, setstats(NewState)};
-        {error, _Error} -> {noreply, State}
+    case do_subscribe(Topic, Subscriber, Options, State) of
+        {ok, NewState}  -> noreply(setstats(NewState));
+        {error, _Error} -> noreply(State)
     end;
     end;
 
 
 handle_cast({unsubscribe, Topic, Subscriber}, State) ->
 handle_cast({unsubscribe, Topic, Subscriber}, State) ->
-    case do_unsubscribe_(Topic, Subscriber, State) of
-        {ok, NewState}  -> {noreply, setstats(NewState), hibernate};
-        {error, _Error} -> {noreply, State}
+    case do_unsubscribe(Topic, Subscriber, State) of
+        {ok, NewState}  -> noreply(setstats(NewState));
+        {error, _Error} -> noreply(State)
     end;
     end;
 
 
-handle_cast({subscriber_down, Subscriber}, State) ->
-    subscriber_down_(Subscriber),
-    {noreply, setstats(State)};
-
 handle_cast(Msg, State) ->
 handle_cast(Msg, State) ->
     ?UNEXPECTED_MSG(Msg, State).
     ?UNEXPECTED_MSG(Msg, State).
 
 
-handle_info({'DOWN', _MRef, process, DownPid, _Reason}, State = #state{submon = PM}) ->
-    subscriber_down_(DownPid),
-    {noreply, setstats(State#state{submon = PM:erase(DownPid)}), hibernate};
+handle_info({'DOWN', _MRef, process, DownPid, _Reason}, State = #state{subids = SubIds}) ->
+    case maps:find(DownPid, SubIds) of
+        {ok, SubId} ->
+            clean_subscriber({SubId, DownPid});
+        error ->
+            clean_subscriber(DownPid)
+    end,
+    noreply(setstats(demonitor_subscriber(DownPid, State)));
 
 
 handle_info(Info, State) ->
 handle_info(Info, State) ->
     ?UNEXPECTED_INFO(Info, State).
     ?UNEXPECTED_INFO(Info, State).
@@ -234,62 +255,54 @@ code_change(_OldVsn, State, _Extra) ->
 %% Internal Functions
 %% Internal Functions
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
-do_subscribe_(Topic, Subscriber, Options, State) ->
+do_subscribe(Topic, Subscriber, Options, State) ->
     case ets:lookup(mqtt_subproperty, {Topic, Subscriber}) of
     case ets:lookup(mqtt_subproperty, {Topic, Subscriber}) of
         [] ->
         [] ->
             emqttd_pubsub:async_subscribe(Topic, Subscriber, Options),
             emqttd_pubsub:async_subscribe(Topic, Subscriber, Options),
             Share = proplists:get_value(share, Options),
             Share = proplists:get_value(share, Options),
-            add_subscription_(Share, Subscriber, Topic),
+            add_subscription(Share, Subscriber, Topic),
             ets:insert(mqtt_subproperty, {{Topic, Subscriber}, Options}),
             ets:insert(mqtt_subproperty, {{Topic, Subscriber}, Options}),
-            {ok, monitor_subpid(Subscriber, State)};
+            {ok, monitor_subscriber(Subscriber, State)};
         [_] ->
         [_] ->
             {error, {already_subscribed, Topic}}
             {error, {already_subscribed, Topic}}
     end.
     end.
 
 
-add_subscription_(undefined, Subscriber, Topic) ->
+add_subscription(undefined, Subscriber, Topic) ->
     ets:insert(mqtt_subscription, {Subscriber, Topic});
     ets:insert(mqtt_subscription, {Subscriber, Topic});
-add_subscription_(Share, Subscriber, Topic) ->
-    ets:insert(mqtt_subscription, {Subscriber, {Share, Topic}}).
+add_subscription(Share, Subscriber, Topic) ->
+    ets:insert(mqtt_subscription, {Subscriber, {share, Share, Topic}}).
 
 
-monitor_subpid(SubPid, State = #state{submon = PMon}) when is_pid(SubPid) ->
-    State#state{submon = PMon:monitor(SubPid)};
-monitor_subpid(_SubPid, State) ->
-    State.
+monitor_subscriber(SubPid, State = #state{submon = SubMon}) when is_pid(SubPid) ->
+    State#state{submon = SubMon:monitor(SubPid)};
+monitor_subscriber({SubId, SubPid}, State = #state{subids = SubIds, submon = SubMon}) ->
+    State#state{subids = maps:put(SubPid, SubId, SubIds), submon = SubMon:monitor(SubPid)}.
 
 
-do_unsubscribe_(Topic, Subscriber, State) ->
+do_unsubscribe(Topic, Subscriber, State) ->
     case ets:lookup(mqtt_subproperty, {Topic, Subscriber}) of
     case ets:lookup(mqtt_subproperty, {Topic, Subscriber}) of
         [{_, Options}] ->
         [{_, Options}] ->
             emqttd_pubsub:async_unsubscribe(Topic, Subscriber, Options),
             emqttd_pubsub:async_unsubscribe(Topic, Subscriber, Options),
             Share = proplists:get_value(share, Options),
             Share = proplists:get_value(share, Options),
-            del_subscription_(Share, Subscriber, Topic),
+            del_subscription(Share, Subscriber, Topic),
             ets:delete(mqtt_subproperty, {Topic, Subscriber}),
             ets:delete(mqtt_subproperty, {Topic, Subscriber}),
-            {ok, case ets:member(mqtt_subscription, Subscriber) of
-                true  -> State;
-                false -> demonitor_subpid(Subscriber, State)
-            end};
+            {ok, State};
         [] ->
         [] ->
             {error, {subscription_not_found, Topic}}
             {error, {subscription_not_found, Topic}}
     end.
     end.
 
 
-del_subscription_(undefined, Subscriber, Topic) ->
+del_subscription(undefined, Subscriber, Topic) ->
     ets:delete_object(mqtt_subscription, {Subscriber, Topic});
     ets:delete_object(mqtt_subscription, {Subscriber, Topic});
-del_subscription_(Share, Subscriber, Topic) ->
-    ets:delete_object(mqtt_subscription, {Subscriber, {Share, Topic}}).
+del_subscription(Share, Subscriber, Topic) ->
+    ets:delete_object(mqtt_subscription, {Subscriber, {share, Share, Topic}}).
 
 
-demonitor_subpid(SubPid, State = #state{submon = PMon}) when is_pid(SubPid) ->
-    State#state{submon = PMon:demonitor(SubPid)};
-demonitor_subpid(_SubPid, State) ->
-    State.
-
-subscriber_down_(Subscriber) ->
-    lists:foreach(fun({_, {Share, Topic}}) ->
-                        subscriber_down_(Share, Subscriber, Topic);
+clean_subscriber(Subscriber) ->
+    lists:foreach(fun({_, {share, Share, Topic}}) ->
+                      clean_subscriber(Share, Subscriber, Topic);
                      ({_, Topic}) ->
                      ({_, Topic}) ->
-                        subscriber_down_(undefined, Subscriber, Topic)
+                      clean_subscriber(undefined, Subscriber, Topic)
         end, ets:lookup(mqtt_subscription, Subscriber)),
         end, ets:lookup(mqtt_subscription, Subscriber)),
     ets:delete(mqtt_subscription, Subscriber).
     ets:delete(mqtt_subscription, Subscriber).
 
 
-subscriber_down_(Share, Subscriber, Topic) ->
+clean_subscriber(Share, Subscriber, Topic) ->
     case ets:lookup(mqtt_subproperty, {Topic, Subscriber}) of
     case ets:lookup(mqtt_subproperty, {Topic, Subscriber}) of
         [] ->
         [] ->
             %% TODO:....???
             %% TODO:....???
@@ -300,7 +313,16 @@ subscriber_down_(Share, Subscriber, Topic) ->
             ets:delete(mqtt_subproperty, {Topic, Subscriber})
             ets:delete(mqtt_subproperty, {Topic, Subscriber})
     end.
     end.
 
 
+demonitor_subscriber(SubPid, State = #state{subids = SubIds, submon = SubMon}) ->
+    State#state{subids = maps:remove(SubPid, SubIds), submon = SubMon:demonitor(SubPid)}.
+
 setstats(State) ->
 setstats(State) ->
     emqttd_stats:setstats('subscriptions/count', 'subscriptions/max',
     emqttd_stats:setstats('subscriptions/count', 'subscriptions/max',
                           ets:info(mqtt_subscription, size)), State.
                           ets:info(mqtt_subscription, size)), State.
 
 
+reply(Reply, State) ->
+    {reply, Reply, State, hibernate}.
+
+noreply(State) ->
+    {noreply, State, hibernate}.
+

+ 4 - 4
src/emqttd_session.erl

@@ -172,7 +172,7 @@
                         "Session(~s): " ++ Format, [State#state.client_id | Args])).
                         "Session(~s): " ++ Format, [State#state.client_id | Args])).
 
 
 %% @doc Start a Session
 %% @doc Start a Session
--spec(start_link(boolean(), {mqtt_client_id(), mqtt_username()}, pid()) -> {ok, pid()} | {error, any()}).
+-spec(start_link(boolean(), {mqtt_client_id(), mqtt_username()}, pid()) -> {ok, pid()} | {error, term()}).
 start_link(CleanSess, {ClientId, Username}, ClientPid) ->
 start_link(CleanSess, {ClientId, Username}, ClientPid) ->
     gen_server2:start_link(?MODULE, [CleanSess, {ClientId, Username}, ClientPid], []).
     gen_server2:start_link(?MODULE, [CleanSess, {ClientId, Username}, ClientPid], []).
 
 
@@ -192,7 +192,7 @@ subscribe(Session, PacketId, TopicTable) -> %%TODO: the ack function??...
     gen_server2:cast(Session, {subscribe, From, TopicTable, AckFun}).
     gen_server2:cast(Session, {subscribe, From, TopicTable, AckFun}).
 
 
 %% @doc Publish Message
 %% @doc Publish Message
--spec(publish(pid(), mqtt_message()) -> ok | {error, any()}).
+-spec(publish(pid(), mqtt_message()) -> ok | {error, term()}).
 publish(_Session, Msg = #mqtt_message{qos = ?QOS_0}) ->
 publish(_Session, Msg = #mqtt_message{qos = ?QOS_0}) ->
     %% Publish QoS0 Directly
     %% Publish QoS0 Directly
     emqttd_server:publish(Msg), ok;
     emqttd_server:publish(Msg), ok;
@@ -582,9 +582,9 @@ handle_info(Info, Session) ->
     ?UNEXPECTED_INFO(Info, Session).
     ?UNEXPECTED_INFO(Info, Session).
 
 
 terminate(Reason, #state{client_id = ClientId, username = Username}) ->
 terminate(Reason, #state{client_id = ClientId, username = Username}) ->
-    emqttd_stats:del_session_stats(ClientId),
+    %% Move to emqttd_sm to avoid race condition
+    %%emqttd_stats:del_session_stats(ClientId),
     emqttd_hooks:run('session.terminated', [ClientId, Username, Reason]),
     emqttd_hooks:run('session.terminated', [ClientId, Username, Reason]),
-    emqttd_server:subscriber_down(ClientId),
     emqttd_sm:unregister_session(ClientId).
     emqttd_sm:unregister_session(ClientId).
 
 
 code_change(_OldVsn, Session, _Extra) ->
 code_change(_OldVsn, Session, _Extra) ->

+ 3 - 3
src/emqttd_sm.erl

@@ -76,12 +76,12 @@ mnesia(copy) ->
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 
 
 %% @doc Start a session manager
 %% @doc Start a session manager
--spec(start_link(atom(), pos_integer()) -> {ok, pid()} | ignore | {error, any()}).
+-spec(start_link(atom(), pos_integer()) -> {ok, pid()} | ignore | {error, term()}).
 start_link(Pool, Id) ->
 start_link(Pool, Id) ->
     gen_server2:start_link({local, ?PROC_NAME(?MODULE, Id)}, ?MODULE, [Pool, Id], []).
     gen_server2:start_link({local, ?PROC_NAME(?MODULE, Id)}, ?MODULE, [Pool, Id], []).
 
 
 %% @doc Start a session
 %% @doc Start a session
--spec(start_session(boolean(), {binary(), binary() | undefined}) -> {ok, pid(), boolean()} | {error, any()}).
+-spec(start_session(boolean(), {binary(), binary() | undefined}) -> {ok, pid(), boolean()} | {error, term()}).
 start_session(CleanSess, {ClientId, Username}) ->
 start_session(CleanSess, {ClientId, Username}) ->
     SM = gproc_pool:pick_worker(?POOL, ClientId),
     SM = gproc_pool:pick_worker(?POOL, ClientId),
     call(SM, {start_session, CleanSess, {ClientId, Username}, self()}).
     call(SM, {start_session, CleanSess, {ClientId, Username}, self()}).
@@ -107,6 +107,7 @@ unregister_session(ClientId) ->
 unregister_session(ClientId, Pid) ->
 unregister_session(ClientId, Pid) ->
     case ets:lookup(mqtt_local_session, ClientId) of
     case ets:lookup(mqtt_local_session, ClientId) of
         [LocalSess = {_, Pid, _, _}] ->
         [LocalSess = {_, Pid, _, _}] ->
+            emqttd_stats:del_session_stats(ClientId),
             ets:delete_object(mqtt_local_session, LocalSess);
             ets:delete_object(mqtt_local_session, LocalSess);
         _ ->
         _ ->
             false
             false
@@ -187,7 +188,6 @@ handle_info({'DOWN', MRef, process, DownPid, _Reason}, State) ->
                     [] ->
                     [] ->
                         ok;
                         ok;
                     [Sess = #mqtt_session{sess_pid = DownPid}] ->
                     [Sess = #mqtt_session{sess_pid = DownPid}] ->
-                        emqttd_stats:del_session_stats(ClientId),
                         mnesia:delete_object(mqtt_session, Sess, write);
                         mnesia:delete_object(mqtt_session, Sess, write);
                     [_Sess] ->
                     [_Sess] ->
                         ok
                         ok

+ 1 - 1
src/emqttd_sm_helper.erl

@@ -39,7 +39,7 @@
 -define(LOCK, {?MODULE, clean_sessions}).
 -define(LOCK, {?MODULE, clean_sessions}).
 
 
 %% @doc Start a session helper
 %% @doc Start a session helper
--spec(start_link(fun()) -> {ok, pid()} | ignore | {error, any()}).
+-spec(start_link(fun()) -> {ok, pid()} | ignore | {error, term()}).
 start_link(StatsFun) ->
 start_link(StatsFun) ->
     gen_server:start_link({local, ?MODULE}, ?MODULE, [StatsFun], []).
     gen_server:start_link({local, ?MODULE}, ?MODULE, [StatsFun], []).
 
 

+ 2 - 2
src/emqttd_trace.erl

@@ -47,7 +47,7 @@ start_link() ->
     gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
     gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
 
 
 %% @doc Start to trace client or topic.
 %% @doc Start to trace client or topic.
--spec(start_trace(trace_who(), string()) -> ok | {error, any()}).
+-spec(start_trace(trace_who(), string()) -> ok | {error, term()}).
 start_trace({client, ClientId}, LogFile) ->
 start_trace({client, ClientId}, LogFile) ->
     start_trace({start_trace, {client, ClientId}, LogFile});
     start_trace({start_trace, {client, ClientId}, LogFile});
 
 
@@ -57,7 +57,7 @@ start_trace({topic, Topic}, LogFile) ->
 start_trace(Req) -> gen_server:call(?MODULE, Req, infinity).
 start_trace(Req) -> gen_server:call(?MODULE, Req, infinity).
 
 
 %% @doc Stop tracing client or topic.
 %% @doc Stop tracing client or topic.
--spec(stop_trace(trace_who()) -> ok | {error, any()}).
+-spec(stop_trace(trace_who()) -> ok | {error, term()}).
 stop_trace({client, ClientId}) ->
 stop_trace({client, ClientId}) ->
     gen_server:call(?MODULE, {stop_trace, {client, ClientId}});
     gen_server:call(?MODULE, {stop_trace, {client, ClientId}});
 stop_trace({topic, Topic}) ->
 stop_trace({topic, Topic}) ->

+ 48 - 48
src/emqttd_trie.erl

@@ -31,7 +31,7 @@
 -copy_mnesia({mnesia, [copy]}).
 -copy_mnesia({mnesia, [copy]}).
 
 
 %% Trie API
 %% Trie API
--export([insert/1, match/1, delete/1, lookup/1]).
+-export([insert/1, match/1, lookup/1, delete/1]).
 
 
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
 %% Mnesia Callbacks
 %% Mnesia Callbacks
@@ -65,22 +65,22 @@ mnesia(copy) ->
 -spec(insert(Topic :: binary()) -> ok).
 -spec(insert(Topic :: binary()) -> ok).
 insert(Topic) when is_binary(Topic) ->
 insert(Topic) when is_binary(Topic) ->
     case mnesia:read(mqtt_trie_node, Topic) of
     case mnesia:read(mqtt_trie_node, Topic) of
-    [#trie_node{topic=Topic}] ->
-        ok;
-    [TrieNode=#trie_node{topic=undefined}] ->
-        write_trie_node(TrieNode#trie_node{topic=Topic});
-    [] ->
-        % Add trie path
-        lists:foreach(fun add_path/1, emqttd_topic:triples(Topic)),
-        % Add last node
-        write_trie_node(#trie_node{node_id=Topic, topic=Topic})
+        [#trie_node{topic = Topic}] ->
+            ok;
+        [TrieNode = #trie_node{topic = undefined}] ->
+            write_trie_node(TrieNode#trie_node{topic = Topic});
+        [] ->
+            % Add trie path
+            lists:foreach(fun add_path/1, emqttd_topic:triples(Topic)),
+            % Add last node
+            write_trie_node(#trie_node{node_id = Topic, topic = Topic})
     end.
     end.
 
 
 %% @doc Find trie nodes that match topic
 %% @doc Find trie nodes that match topic
 -spec(match(Topic :: binary()) -> list(MatchedTopic :: binary())).
 -spec(match(Topic :: binary()) -> list(MatchedTopic :: binary())).
 match(Topic) when is_binary(Topic) ->
 match(Topic) when is_binary(Topic) ->
     TrieNodes = match_node(root, emqttd_topic:words(Topic)),
     TrieNodes = match_node(root, emqttd_topic:words(Topic)),
-    [Name || #trie_node{topic=Name} <- TrieNodes, Name =/= undefined].
+    [Name || #trie_node{topic = Name} <- TrieNodes, Name =/= undefined].
 
 
 %% @doc Lookup a Trie Node
 %% @doc Lookup a Trie Node
 -spec(lookup(NodeId :: binary()) -> [#trie_node{}]).
 -spec(lookup(NodeId :: binary()) -> [#trie_node{}]).
@@ -91,13 +91,13 @@ lookup(NodeId) ->
 -spec(delete(Topic :: binary()) -> ok).
 -spec(delete(Topic :: binary()) -> ok).
 delete(Topic) when is_binary(Topic) ->
 delete(Topic) when is_binary(Topic) ->
     case mnesia:read(mqtt_trie_node, Topic) of
     case mnesia:read(mqtt_trie_node, Topic) of
-    [#trie_node{edge_count=0}] ->
-        mnesia:delete({mqtt_trie_node, Topic}),
-        delete_path(lists:reverse(emqttd_topic:triples(Topic)));
-    [TrieNode] ->
-        write_trie_node(TrieNode#trie_node{topic = undefined});
-    [] ->
-        ok
+        [#trie_node{edge_count = 0}] ->
+            mnesia:delete({mqtt_trie_node, Topic}),
+            delete_path(lists:reverse(emqttd_topic:triples(Topic)));
+        [TrieNode] ->
+            write_trie_node(TrieNode#trie_node{topic = undefined});
+        [] ->
+            ok
     end.
     end.
 
 
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
@@ -107,19 +107,19 @@ delete(Topic) when is_binary(Topic) ->
 %% @private
 %% @private
 %% @doc Add path to trie tree.
 %% @doc Add path to trie tree.
 add_path({Node, Word, Child}) ->
 add_path({Node, Word, Child}) ->
-    Edge = #trie_edge{node_id=Node, word=Word},
+    Edge = #trie_edge{node_id = Node, word = Word},
     case mnesia:read(mqtt_trie_node, Node) of
     case mnesia:read(mqtt_trie_node, Node) of
-    [TrieNode = #trie_node{edge_count=Count}] ->
-        case mnesia:wread({mqtt_trie, Edge}) of
-        [] -> 
-            write_trie_node(TrieNode#trie_node{edge_count=Count+1}),
-            write_trie(#trie{edge=Edge, node_id=Child});
-        [_] -> 
-            ok
-        end;
-    [] ->
-        write_trie_node(#trie_node{node_id=Node, edge_count=1}),
-        write_trie(#trie{edge=Edge, node_id=Child})
+        [TrieNode = #trie_node{edge_count = Count}] ->
+            case mnesia:wread({mqtt_trie, Edge}) of
+                [] ->
+                    write_trie_node(TrieNode#trie_node{edge_count = Count+1}),
+                    write_trie(#trie{edge = Edge, node_id = Child});
+                [_] ->
+                    ok
+            end;
+        [] ->
+            write_trie_node(#trie_node{node_id = Node, edge_count = 1}),
+            write_trie(#trie{edge = Edge, node_id = Child})
     end.
     end.
 
 
 %% @private
 %% @private
@@ -135,20 +135,20 @@ match_node(NodeId, [], ResAcc) ->
 
 
 match_node(NodeId, [W|Words], ResAcc) ->
 match_node(NodeId, [W|Words], ResAcc) ->
     lists:foldl(fun(WArg, Acc) ->
     lists:foldl(fun(WArg, Acc) ->
-        case mnesia:read(mqtt_trie, #trie_edge{node_id=NodeId, word=WArg}) of
-        [#trie{node_id=ChildId}] -> match_node(ChildId, Words, Acc);
-        [] -> Acc
+        case mnesia:read(mqtt_trie, #trie_edge{node_id = NodeId, word = WArg}) of
+            [#trie{node_id = ChildId}] -> match_node(ChildId, Words, Acc);
+            [] -> Acc
         end
         end
     end, 'match_#'(NodeId, ResAcc), [W, '+']).
     end, 'match_#'(NodeId, ResAcc), [W, '+']).
 
 
 %% @private
 %% @private
 %% @doc Match node with '#'.
 %% @doc Match node with '#'.
 'match_#'(NodeId, ResAcc) ->
 'match_#'(NodeId, ResAcc) ->
-    case mnesia:read(mqtt_trie, #trie_edge{node_id=NodeId, word = '#'}) of
-    [#trie{node_id=ChildId}] ->
-        mnesia:read(mqtt_trie_node, ChildId) ++ ResAcc;
-    [] ->
-        ResAcc
+    case mnesia:read(mqtt_trie, #trie_edge{node_id = NodeId, word = '#'}) of
+        [#trie{node_id = ChildId}] ->
+            mnesia:read(mqtt_trie_node, ChildId) ++ ResAcc;
+        [] ->
+            ResAcc
     end.
     end.
 
 
 %% @private
 %% @private
@@ -156,17 +156,17 @@ match_node(NodeId, [W|Words], ResAcc) ->
 delete_path([]) ->
 delete_path([]) ->
     ok;
     ok;
 delete_path([{NodeId, Word, _} | RestPath]) ->
 delete_path([{NodeId, Word, _} | RestPath]) ->
-    mnesia:delete({mqtt_trie, #trie_edge{node_id=NodeId, word=Word}}),
+    mnesia:delete({mqtt_trie, #trie_edge{node_id = NodeId, word = Word}}),
     case mnesia:read(mqtt_trie_node, NodeId) of
     case mnesia:read(mqtt_trie_node, NodeId) of
-    [#trie_node{edge_count=1, topic=undefined}] -> 
-        mnesia:delete({mqtt_trie_node, NodeId}),
-        delete_path(RestPath);
-    [TrieNode=#trie_node{edge_count=1, topic=_}] -> 
-        write_trie_node(TrieNode#trie_node{edge_count=0});
-    [TrieNode=#trie_node{edge_count=C}] ->
-        write_trie_node(TrieNode#trie_node{edge_count=C-1});
-    [] ->
-        throw({notfound, NodeId}) 
+        [#trie_node{edge_count = 1, topic = undefined}] ->
+            mnesia:delete({mqtt_trie_node, NodeId}),
+            delete_path(RestPath);
+        [TrieNode = #trie_node{edge_count = 1, topic = _}] ->
+            write_trie_node(TrieNode#trie_node{edge_count = 0});
+        [TrieNode = #trie_node{edge_count = C}] ->
+            write_trie_node(TrieNode#trie_node{edge_count = C-1});
+        [] ->
+            mnesia:abort({node_not_found, NodeId})
     end.
     end.
 
 
 %% @private
 %% @private

+ 14 - 11
test/emqttd_SUITE.erl

@@ -240,8 +240,10 @@ t_local_subscribe(_) ->
     emqttd:subscribe("$local/topic2", <<"x">>, [{qos, 2}]),
     emqttd:subscribe("$local/topic2", <<"x">>, [{qos, 2}]),
     timer:sleep(10),
     timer:sleep(10),
     ?assertEqual([self()], emqttd:subscribers("$local/topic0")),
     ?assertEqual([self()], emqttd:subscribers("$local/topic0")),
-    ?assertEqual([<<"x">>], emqttd:subscribers("$local/topic1")),
-    ?assertEqual([{<<"$local/topic1">>,<<"x">>,[]},{<<"$local/topic2">>,<<"x">>,[{qos,2}]}], emqttd:subscriptions(<<"x">>)),
+    ?assertEqual([{<<"x">>, self()}], emqttd:subscribers("$local/topic1")),
+    ?assertEqual([{{<<"x">>, self()}, <<"$local/topic1">>, []},
+                  {{<<"x">>, self()}, <<"$local/topic2">>, [{qos,2}]}],
+                 emqttd:subscriptions(<<"x">>)),
     
     
     ?assertEqual(ok, emqttd:unsubscribe("$local/topic0")),
     ?assertEqual(ok, emqttd:unsubscribe("$local/topic0")),
     ?assertMatch({error, {subscription_not_found, _}}, emqttd:unsubscribe("$local/topic0")),
     ?assertMatch({error, {subscription_not_found, _}}, emqttd:unsubscribe("$local/topic0")),
@@ -256,9 +258,9 @@ t_shared_subscribe(_) ->
     emqttd:subscribe("$queue/topic3"),
     emqttd:subscribe("$queue/topic3"),
     timer:sleep(10),
     timer:sleep(10),
     ?assertEqual([self()], emqttd:subscribers(<<"$local/$share/group1/topic1">>)),
     ?assertEqual([self()], emqttd:subscribers(<<"$local/$share/group1/topic1">>)),
-    ?assertEqual([{<<"$local/$share/group1/topic1">>, self(), []},
-                  {<<"$queue/topic3">>, self(), []},
-                  {<<"$share/group2/topic2">>, self(), []}],
+    ?assertEqual([{self(), <<"$local/$share/group1/topic1">>, []},
+                  {self(), <<"$queue/topic3">>, []},
+                  {self(), <<"$share/group2/topic2">>, []}],
                  lists:sort(emqttd:subscriptions(self()))),
                  lists:sort(emqttd:subscriptions(self()))),
     emqttd:unsubscribe("$local/$share/group1/topic1"),
     emqttd:unsubscribe("$local/$share/group1/topic1"),
     emqttd:unsubscribe("$share/group2/topic2"),
     emqttd:unsubscribe("$share/group2/topic2"),
@@ -298,7 +300,7 @@ router_add_del(_) ->
     %% Add
     %% Add
     emqttd_router:add_route(<<"#">>),
     emqttd_router:add_route(<<"#">>),
     emqttd_router:add_route(<<"a/b/c">>),
     emqttd_router:add_route(<<"a/b/c">>),
-    emqttd_router:add_route(<<"+/#">>, node()),
+    emqttd_router:add_route(<<"+/#">>),
     Routes = [R1, R2 | _] = [
     Routes = [R1, R2 | _] = [
             #mqtt_route{topic = <<"#">>,     node = node()},
             #mqtt_route{topic = <<"#">>,     node = node()},
             #mqtt_route{topic = <<"+/#">>,   node = node()},
             #mqtt_route{topic = <<"+/#">>,   node = node()},
@@ -306,7 +308,7 @@ router_add_del(_) ->
     Routes = lists:sort(emqttd_router:match(<<"a/b/c">>)),
     Routes = lists:sort(emqttd_router:match(<<"a/b/c">>)),
 
 
     %% Batch Add
     %% Batch Add
-    emqttd_router:add_routes(Routes),
+    lists:foreach(fun(R) -> emqttd_router:add_route(R) end, Routes),
     Routes = lists:sort(emqttd_router:match(<<"a/b/c">>)),
     Routes = lists:sort(emqttd_router:match(<<"a/b/c">>)),
 
 
     %% Del
     %% Del
@@ -317,7 +319,8 @@ router_add_del(_) ->
     %% Batch Del
     %% Batch Del
     R3 = #mqtt_route{topic = <<"#">>, node = 'a@127.0.0.1'},
     R3 = #mqtt_route{topic = <<"#">>, node = 'a@127.0.0.1'},
     emqttd_router:add_route(R3),
     emqttd_router:add_route(R3),
-    emqttd_router:del_routes([R1, R2]),
+    emqttd_router:del_route(R1),
+    emqttd_router:del_route(R2),
     emqttd_router:del_route(R3),
     emqttd_router:del_route(R3),
     [] = lists:sort(emqttd_router:match(<<"a/b/c">>)).
     [] = lists:sort(emqttd_router:match(<<"a/b/c">>)).
 
 
@@ -325,7 +328,7 @@ router_print(_) ->
     Routes = [#mqtt_route{topic = <<"a/b/c">>, node = node()},
     Routes = [#mqtt_route{topic = <<"a/b/c">>, node = node()},
               #mqtt_route{topic = <<"#">>,     node = node()},
               #mqtt_route{topic = <<"#">>,     node = node()},
               #mqtt_route{topic = <<"+/#">>,   node = node()}],
               #mqtt_route{topic = <<"+/#">>,   node = node()}],
-    emqttd_router:add_routes(Routes),
+    lists:foreach(fun(R) -> emqttd_router:add_route(R) end, Routes),
     emqttd_router:print(<<"a/b/c">>).
     emqttd_router:print(<<"a/b/c">>).
 
 
 router_unused(_) ->
 router_unused(_) ->
@@ -589,9 +592,9 @@ conflict_listeners(_) ->
                {current_clients, esockd:get_current_clients(Pid)},
                {current_clients, esockd:get_current_clients(Pid)},
                {shutdown_count, esockd:get_shutdown_count(Pid)}]}
                {shutdown_count, esockd:get_shutdown_count(Pid)}]}
               end, esockd:listeners()),
               end, esockd:listeners()),
-    L =proplists:get_value("mqtt:tcp:0.0.0.0:1883", Listeners),
+    L = proplists:get_value("mqtt:tcp:0.0.0.0:1883", Listeners),
     ?assertEqual(1, proplists:get_value(current_clients, L)),
     ?assertEqual(1, proplists:get_value(current_clients, L)),
-    ?assertEqual(1, proplists:get_value(conflict, L)),
+    ?assertEqual(1, proplists:get_value(conflict, proplists:get_value(shutdown_count, L))),
     emqttc:disconnect(C2).
     emqttc:disconnect(C2).
 
 
 cli_vm(_) ->
 cli_vm(_) ->

+ 132 - 0
test/emqttd_router_SUITE.erl

@@ -0,0 +1,132 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2013-2017 EMQ Enterprise, Inc. (http://emqtt.io)
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%%     http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+
+-module(emqttd_router_SUITE).
+
+-compile(export_all).
+
+-include("emqttd.hrl").
+
+-include_lib("eunit/include/eunit.hrl").
+
+-define(R, emqttd_router).
+
+all() ->
+    [{group, route},
+     {group, local_route}].
+
+groups() ->
+    [{route, [sequence],
+      [t_get_topics,
+       t_add_del_route,
+       t_match_route,
+       t_print,
+       t_has_route]},
+     {local_route, [sequence],
+      [t_get_local_topics,
+       t_add_del_local_route,
+       t_match_local_route]}].
+
+init_per_suite(Config) ->
+    ekka:start(),
+    ekka_mnesia:ensure_started(),
+    {ok, _R} = emqttd_router:start(),
+    Config.
+
+end_per_suite(_Config) ->
+    emqttd_router:stop(),
+    ekka:stop(),
+    ekka_mnesia:ensure_stopped(),
+    ekka_mnesia:delete_schema().
+
+init_per_testcase(_TestCase, Config) ->
+    Config.
+
+end_per_testcase(_TestCase, _Config) ->
+    clear_tables().
+
+t_get_topics(_) ->
+    ?R:add_route(<<"a/b/c">>),
+    ?R:add_route(<<"a/b/c">>),
+    ?R:add_route(<<"a/+/b">>),
+    ?assertEqual([<<"a/+/b">>, <<"a/b/c">>], lists:sort(?R:topics())),
+    ?R:del_route(<<"a/b/c">>),
+    ?R:del_route(<<"a/+/b">>),
+    ?assertEqual([], lists:sort(?R:topics())).
+
+t_add_del_route(_) ->
+    %%Node = node(),
+    ?R:add_route(<<"a/b/c">>),
+    ?R:add_route(<<"a/+/b">>),
+    ?R:del_route(<<"a/b/c">>),
+    ?R:del_route(<<"a/+/b">>).
+
+t_match_route(_) ->
+    Node = node(),
+    ?R:add_route(<<"a/b/c">>),
+    ?R:add_route(<<"a/+/c">>),
+    ?R:add_route(<<"a/b/#">>),
+    ?R:add_route(<<"#">>),
+    ?assertEqual([#mqtt_route{topic = <<"#">>, node = Node},
+                  #mqtt_route{topic = <<"a/+/c">>, node = Node},
+                  #mqtt_route{topic = <<"a/b/#">>, node = Node},
+                  #mqtt_route{topic = <<"a/b/c">>, node = Node}],
+                 lists:sort(?R:match(<<"a/b/c">>))).
+
+t_print(_) ->
+    ?R:add_route(<<"topic">>),
+    ?R:add_route(<<"topic/#">>),
+    ?R:print(<<"topic">>).
+
+t_has_route(_) ->
+    ?R:add_route(<<"devices/+/messages">>),
+    ?assert(?R:has_route(<<"devices/+/messages">>)).
+
+t_get_local_topics(_) ->
+    ?R:add_local_route(<<"a/b/c">>),
+    ?R:add_local_route(<<"x/+/y">>),
+    ?R:add_local_route(<<"z/#">>),
+    ?assertEqual([<<"z/#">>, <<"x/+/y">>, <<"a/b/c">>], ?R:local_topics()),
+    ?R:del_local_route(<<"x/+/y">>),
+    ?R:del_local_route(<<"z/#">>),
+    ?assertEqual([<<"a/b/c">>], ?R:local_topics()).
+
+t_add_del_local_route(_) ->
+    Node = node(),
+    ?R:add_local_route(<<"a/b/c">>),
+    ?R:add_local_route(<<"x/+/y">>),
+    ?R:add_local_route(<<"z/#">>),
+    ?assertEqual([{<<"a/b/c">>, Node},
+                  {<<"x/+/y">>, Node},
+                  {<<"z/#">>, Node}],
+                 lists:sort(?R:get_local_routes())),
+    ?R:del_local_route(<<"x/+/y">>),
+    ?R:del_local_route(<<"z/#">>),
+    ?assertEqual([{<<"a/b/c">>, Node}], lists:sort(?R:get_local_routes())).
+
+t_match_local_route(_) ->
+    ?R:add_local_route(<<"$SYS/#">>),
+    ?R:add_local_route(<<"a/b/c">>),
+    ?R:add_local_route(<<"a/+/c">>),
+    ?R:add_local_route(<<"a/b/#">>),
+    ?R:add_local_route(<<"#">>),
+    Matched = [Topic || #mqtt_route{topic = {local, Topic}} <- ?R:match_local(<<"a/b/c">>)],
+    ?assertEqual([<<"#">>, <<"a/+/c">>, <<"a/b/#">>, <<"a/b/c">>], lists:sort(Matched)).
+
+clear_tables() ->
+    ?R:clean_local_routes(),
+    lists:foreach(fun mnesia:clear_table/1, [mqtt_route, mqtt_trie, mqtt_trie_node]).
+