%% emqttd_sm.erl (10 KB)


%%--------------------------------------------------------------------
%% Copyright (c) 2012-2016 Feng Lee <feng@emqtt.io>.
%%
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%%--------------------------------------------------------------------
  16. %% @doc Session Manager
  17. -module(emqttd_sm).
  18. -behaviour(gen_server2).
  19. -include("emqttd.hrl").
  20. -include("emqttd_internal.hrl").
  21. %% Mnesia Callbacks
  22. -export([mnesia/1]).
  23. -boot_mnesia({mnesia, [boot]}).
  24. -copy_mnesia({mnesia, [copy]}).
  25. %% API Function Exports
  26. -export([start_link/2]).
  27. -export([start_session/2, lookup_session/1, reg_session/3, unreg_session/1]).
  28. -export([dispatch/3]).
  29. %% gen_server Function Exports
  30. -export([init/1, handle_call/3, handle_cast/2, handle_info/2,
  31. terminate/2, code_change/3]).
  32. %% gen_server2 priorities
  33. -export([prioritise_call/4, prioritise_cast/3, prioritise_info/3]).
  34. -record(state, {pool, id, monitors}).
  35. -define(POOL, ?MODULE).
  36. -define(TIMEOUT, 120000).
  37. -define(LOG(Level, Format, Args, Session),
  38. lager:Level("SM(~s): " ++ Format, [Session#mqtt_session.client_id | Args])).
  39. %%--------------------------------------------------------------------
  40. %% Mnesia callbacks
  41. %%--------------------------------------------------------------------
  42. mnesia(boot) ->
  43. %% Global Session Table
  44. ok = emqttd_mnesia:create_table(mqtt_session, [
  45. {type, set},
  46. {ram_copies, [node()]},
  47. {record_name, mqtt_session},
  48. {attributes, record_info(fields, mqtt_session)}]);
  49. mnesia(copy) ->
  50. ok = emqttd_mnesia:copy_table(mqtt_session).
  51. %%--------------------------------------------------------------------
  52. %% API
  53. %%--------------------------------------------------------------------
  54. %% @doc Start a session manager
  55. -spec(start_link(atom(), pos_integer()) -> {ok, pid()} | ignore | {error, any()}).
  56. start_link(Pool, Id) ->
  57. gen_server2:start_link({local, ?PROC_NAME(?MODULE, Id)}, ?MODULE, [Pool, Id], []).
  58. %% @doc Start a session
  59. -spec(start_session(boolean(), {binary(), binary() | undefined}) -> {ok, pid(), boolean()} | {error, any()}).
  60. start_session(CleanSess, {ClientId, Username}) ->
  61. SM = gproc_pool:pick_worker(?POOL, ClientId),
  62. call(SM, {start_session, CleanSess, {ClientId, Username}, self()}).
  63. %% @doc Lookup a Session
  64. -spec(lookup_session(binary()) -> mqtt_session() | undefined).
  65. lookup_session(ClientId) ->
  66. case mnesia:dirty_read(mqtt_session, ClientId) of
  67. [Session] -> Session;
  68. [] -> undefined
  69. end.
  70. %% @doc Register a session with info.
  71. -spec(reg_session(binary(), boolean(), [tuple()]) -> true).
  72. reg_session(ClientId, CleanSess, Properties) ->
  73. ets:insert(mqtt_local_session, {ClientId, self(), CleanSess, Properties}).
  74. %% @doc Unregister a session.
  75. -spec(unreg_session(binary()) -> true).
  76. unreg_session(ClientId) ->
  77. ets:delete(mqtt_local_session, ClientId).
  78. dispatch(ClientId, Topic, Msg) ->
  79. try ets:lookup_element(mqtt_local_session, ClientId, 2) of
  80. Pid -> Pid ! {dispatch, Topic, Msg}
  81. catch
  82. error:badarg -> io:format("Session Not Found: ~p~n", [ClientId]), ok %%TODO: How??
  83. end.
  84. call(SM, Req) ->
  85. gen_server2:call(SM, Req, ?TIMEOUT). %%infinity).
  86. %%--------------------------------------------------------------------
  87. %% gen_server callbacks
  88. %%--------------------------------------------------------------------
  89. init([Pool, Id]) ->
  90. ?GPROC_POOL(join, Pool, Id),
  91. {ok, #state{pool = Pool, id = Id, monitors = dict:new()}}.
  92. prioritise_call(_Msg, _From, _Len, _State) ->
  93. 1.
  94. prioritise_cast(_Msg, _Len, _State) ->
  95. 0.
  96. prioritise_info(_Msg, _Len, _State) ->
  97. 2.
  98. %% Persistent Session
  99. handle_call({start_session, false, {ClientId, Username}, ClientPid}, _From, State) ->
  100. case lookup_session(ClientId) of
  101. undefined ->
  102. %% Create session locally
  103. create_session({false, {ClientId, Username}, ClientPid}, State);
  104. Session ->
  105. case resume_session(Session, ClientPid) of
  106. {ok, SessPid} ->
  107. {reply, {ok, SessPid, true}, State};
  108. {error, Erorr} ->
  109. {reply, {error, Erorr}, State}
  110. end
  111. end;
  112. %% Transient Session
  113. handle_call({start_session, true, {ClientId, Username}, ClientPid}, _From, State) ->
  114. Client = {true, {ClientId, Username}, ClientPid},
  115. case lookup_session(ClientId) of
  116. undefined ->
  117. create_session(Client, State);
  118. Session ->
  119. case destroy_session(Session) of
  120. ok ->
  121. create_session(Client, State);
  122. {error, Error} ->
  123. {reply, {error, Error}, State}
  124. end
  125. end;
  126. handle_call(Req, _From, State) ->
  127. ?UNEXPECTED_REQ(Req, State).
  128. handle_cast(Msg, State) ->
  129. ?UNEXPECTED_MSG(Msg, State).
  130. handle_info({'DOWN', MRef, process, DownPid, _Reason}, State) ->
  131. case dict:find(MRef, State#state.monitors) of
  132. {ok, ClientId} ->
  133. mnesia:transaction(fun() ->
  134. case mnesia:wread({mqtt_session, ClientId}) of
  135. [] ->
  136. ok;
  137. [Sess = #mqtt_session{sess_pid = DownPid}] ->
  138. mnesia:delete_object(mqtt_session, Sess, write);
  139. [_Sess] ->
  140. ok
  141. end
  142. end),
  143. {noreply, erase_monitor(MRef, State), hibernate};
  144. error ->
  145. lager:error("MRef of session ~p not found", [DownPid]),
  146. {noreply, State}
  147. end;
  148. handle_info(Info, State) ->
  149. ?UNEXPECTED_INFO(Info, State).
  150. terminate(_Reason, #state{pool = Pool, id = Id}) ->
  151. ?GPROC_POOL(leave, Pool, Id).
  152. code_change(_OldVsn, State, _Extra) ->
  153. {ok, State}.
  154. %%--------------------------------------------------------------------
  155. %% Internal functions
  156. %%--------------------------------------------------------------------
  157. %% Create Session Locally
  158. create_session({CleanSess, {ClientId, Username}, ClientPid}, State) ->
  159. case create_session(CleanSess, {ClientId, Username}, ClientPid) of
  160. {ok, SessPid} ->
  161. {reply, {ok, SessPid, false},
  162. monitor_session(ClientId, SessPid, State)};
  163. {error, Error} ->
  164. {reply, {error, Error}, State}
  165. end.
  166. create_session(CleanSess, {ClientId, Username}, ClientPid) ->
  167. case emqttd_session_sup:start_session(CleanSess, {ClientId, Username}, ClientPid) of
  168. {ok, SessPid} ->
  169. Session = #mqtt_session{client_id = ClientId, sess_pid = SessPid, persistent = not CleanSess},
  170. case insert_session(Session) of
  171. {aborted, {conflict, ConflictPid}} ->
  172. %% Conflict with othe node?
  173. lager:error("SM(~s): Conflict with ~p", [ClientId, ConflictPid]),
  174. {error, mnesia_conflict};
  175. {atomic, ok} ->
  176. {ok, SessPid}
  177. end;
  178. {error, Error} ->
  179. {error, Error}
  180. end.
  181. insert_session(Session = #mqtt_session{client_id = ClientId}) ->
  182. mnesia:transaction(
  183. fun() ->
  184. case mnesia:wread({mqtt_session, ClientId}) of
  185. [] ->
  186. mnesia:write(mqtt_session, Session, write);
  187. [#mqtt_session{sess_pid = SessPid}] ->
  188. mnesia:abort({conflict, SessPid})
  189. end
  190. end).
  191. %% Local node
  192. resume_session(Session = #mqtt_session{client_id = ClientId, sess_pid = SessPid}, ClientPid)
  193. when node(SessPid) =:= node() ->
  194. case is_process_alive(SessPid) of
  195. true ->
  196. emqttd_session:resume(SessPid, ClientId, ClientPid),
  197. {ok, SessPid};
  198. false ->
  199. ?LOG(error, "Cannot resume ~p which seems already dead!", [SessPid], Session),
  200. {error, session_died}
  201. end;
  202. %% Remote node
  203. resume_session(Session = #mqtt_session{client_id = ClientId, sess_pid = SessPid}, ClientPid) ->
  204. Node = node(SessPid),
  205. case rpc:call(Node, emqttd_session, resume, [SessPid, ClientId, ClientPid]) of
  206. ok ->
  207. {ok, SessPid};
  208. {badrpc, nodedown} ->
  209. ?LOG(error, "Session died for node '~s' down", [Node], Session),
  210. remove_session(Session),
  211. {error, session_nodedown};
  212. {badrpc, Reason} ->
  213. ?LOG(error, "Failed to resume from node ~s for ~p", [Node, Reason], Session),
  214. {error, Reason}
  215. end.
  216. %% Local node
  217. destroy_session(Session = #mqtt_session{client_id = ClientId, sess_pid = SessPid})
  218. when node(SessPid) =:= node() ->
  219. emqttd_session:destroy(SessPid, ClientId),
  220. remove_session(Session);
  221. %% Remote node
  222. destroy_session(Session = #mqtt_session{client_id = ClientId,
  223. sess_pid = SessPid}) ->
  224. Node = node(SessPid),
  225. case rpc:call(Node, emqttd_session, destroy, [SessPid, ClientId]) of
  226. ok ->
  227. remove_session(Session);
  228. {badrpc, nodedown} ->
  229. ?LOG(error, "Node '~s' down", [Node], Session),
  230. remove_session(Session);
  231. {badrpc, Reason} ->
  232. ?LOG(error, "Failed to destory ~p on remote node ~p for ~s",
  233. [SessPid, Node, Reason], Session),
  234. {error, Reason}
  235. end.
  236. remove_session(Session) ->
  237. case mnesia:transaction(fun mnesia:delete_object/1, [Session]) of
  238. {atomic, ok} -> ok;
  239. {aborted, Error} -> {error, Error}
  240. end.
  241. monitor_session(ClientId, SessPid, State = #state{monitors = Monitors}) ->
  242. MRef = erlang:monitor(process, SessPid),
  243. State#state{monitors = dict:store(MRef, ClientId, Monitors)}.
  244. erase_monitor(MRef, State = #state{monitors = Monitors}) ->
  245. State#state{monitors = dict:erase(MRef, Monitors)}.