emqttd_sm.erl 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. %%--------------------------------------------------------------------
  2. %% Copyright (c) 2012-2016 Feng Lee <feng@emqtt.io>.
  3. %%
  4. %% Licensed under the Apache License, Version 2.0 (the "License");
  5. %% you may not use this file except in compliance with the License.
  6. %% You may obtain a copy of the License at
  7. %%
  8. %% http://www.apache.org/licenses/LICENSE-2.0
  9. %%
  10. %% Unless required by applicable law or agreed to in writing, software
  11. %% distributed under the License is distributed on an "AS IS" BASIS,
  12. %% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. %% See the License for the specific language governing permissions and
  14. %% limitations under the License.
  15. %%--------------------------------------------------------------------
  %% @doc Session Manager
  -module(emqttd_sm).

  -behaviour(gen_server2).

  -include("emqttd.hrl").
  -include("emqttd_internal.hrl").

  %% Mnesia hooks. The custom boot_mnesia / copy_mnesia attributes name the
  %% mnesia/1 phases (presumably collected by emqttd_mnesia at startup --
  %% verify against that module).
  -export([mnesia/1]).
  -boot_mnesia({mnesia, [boot]}).
  -copy_mnesia({mnesia, [copy]}).

  %% Public API
  -export([start_link/2,
           start_session/2,
           lookup_session/1,
           register_session/3,
           unregister_session/2]).

  %% gen_server2 callbacks
  -export([init/1, handle_call/3, handle_cast/2, handle_info/2,
           terminate/2, code_change/3]).

  %% gen_server2 message priorities
  -export([prioritise_call/4, prioritise_cast/3, prioritise_info/3]).

  %% Worker state: the gproc pool joined, this worker's id in it, and a
  %% dict mapping session monitor references to client ids.
  -record(state, {pool, id, monitors}).

  -define(POOL, ?MODULE).

  %% Timeout (ms) for synchronous calls into a session-manager worker.
  -define(TIMEOUT, 120000).

  %% Log with the session's client id as an "SM(ClientId): " prefix.
  -define(LOG(Level, Format, Args, Session),
          lager:Level("SM(~s): " ++ Format, [Session#mqtt_session.client_id | Args])).
  39. %%--------------------------------------------------------------------
  40. %% Mnesia callbacks
  41. %%--------------------------------------------------------------------
  %% @doc Mnesia hooks for the global `session' table (keyed by client id).
  %% copy: replicate the table to this node.
  %% boot: create it as a ram-only set of #mqtt_session{} records.
  mnesia(copy) ->
      ok = emqttd_mnesia:copy_table(session);
  mnesia(boot) ->
      TableOpts = [{type, set},
                   {ram_copies, [node()]},
                   {record_name, mqtt_session},
                   {attributes, record_info(fields, mqtt_session)}],
      ok = emqttd_mnesia:create_table(session, TableOpts).
  51. %%--------------------------------------------------------------------
  52. %% API
  53. %%--------------------------------------------------------------------
  %% @doc Start session-manager worker `Id' of pool `Pool', registered
  %% locally under the name produced by ?PROC_NAME(?MODULE, Id).
  -spec start_link(atom(), pos_integer()) -> {ok, pid()} | ignore | {error, any()}.
  start_link(Pool, Id) ->
      Name = ?PROC_NAME(?MODULE, Id),
      gen_server2:start_link({local, Name}, ?MODULE, [Pool, Id], []).
  %% @doc Start or resume the session of `ClientId' for the calling process.
  %% The request goes to the pool worker gproc_pool picks for `ClientId'.
  %% Success is {ok, SessPid, Resumed}: Resumed is true when an existing
  %% persistent session was resumed, false when a new one was created.
  -spec start_session(CleanSess :: boolean(), binary()) -> {ok, pid(), boolean()} | {error, any()}.
  start_session(CleanSess, ClientId) ->
      Request = {start_session, {CleanSess, ClientId, self()}},
      call(gproc_pool:pick_worker(?POOL, ClientId), Request).
  %% @doc Look up the global session record of `ClientId' with a dirty
  %% (non-transactional) read. Returns undefined when no record exists.
  -spec lookup_session(binary()) -> mqtt_session() | undefined.
  lookup_session(ClientId) ->
      case mnesia:dirty_read(session, ClientId) of
          [] -> undefined;
          [Found] -> Found
      end.
  %% @doc Register the calling process's session for `ClientId' with `Info'.
  %% The row goes into the local ETS table chosen by `CleanSess' (transient
  %% when true, persistent when false) and is keyed by {ClientId, self()}.
  %% Returns the result of ets:insert/2, which is `true' (the previous spec
  %% claimed `ok', which the code never returned).
  -spec register_session(CleanSess, ClientId, Info) -> true when
        CleanSess :: boolean(),
        ClientId :: binary(),
        Info :: [tuple()].
  register_session(CleanSess, ClientId, Info) ->
      ets:insert(sesstab(CleanSess), {{ClientId, self()}, Info}).
  %% @doc Unregister the calling process's session for `ClientId'.
  %% Deletes the {ClientId, self()} row from the local ETS table chosen by
  %% `CleanSess'. Returns the result of ets:delete/2, which is `true' (the
  %% previous spec claimed `ok', which the code never returned).
  -spec unregister_session(CleanSess, ClientId) -> true when
        CleanSess :: boolean(),
        ClientId :: binary().
  unregister_session(CleanSess, ClientId) ->
      ets:delete(sesstab(CleanSess), {ClientId, self()}).
  %% Map the clean-session flag to the local ETS table used by
  %% register_session/3 and unregister_session/2. Non-booleans crash.
  sesstab(Clean) when Clean =:= true -> mqtt_transient_session;
  sesstab(Clean) when Clean =:= false -> mqtt_persistent_session.
  %% Synchronous request to a session-manager worker, bounded by ?TIMEOUT.
  call(Worker, Request) ->
      Timeout = ?TIMEOUT,
      gen_server2:call(Worker, Request, Timeout).
  87. %%--------------------------------------------------------------------
  88. %% gen_server callbacks
  89. %%--------------------------------------------------------------------
  %% Join the gproc pool as worker `Id' and start with no session monitors.
  init([Pool, Id]) ->
      ?GPROC_POOL(join, Pool, Id),
      State = #state{pool = Pool, id = Id, monitors = dict:new()},
      {ok, State}.
  %% gen_server2 priority for every call (i.e. start_session requests): 1.
  %% In this module info messages get 2 and casts get 0.
  prioritise_call(_Request, _Caller, _QueueLen, _S) -> 1.
  %% gen_server2 priority for every cast: 0 (lowest used in this module).
  prioritise_cast(_Cast, _QueueLen, _S) -> 0.
  %% gen_server2 priority for every info message: 2, the highest in this
  %% module -- presumably so 'DOWN' notifications are handled before queued
  %% start_session calls (assumes gen_server2 serves higher values first).
  prioritise_info(_Info, _QueueLen, _S) -> 2.
  %% Persistent session (CleanSess = false): resume the registered session
  %% if one exists, otherwise create a new one locally. The reply's third
  %% element is true only when a session was resumed.
  handle_call({start_session, {false, ClientId, ClientPid} = Client}, _From, State) ->
      case lookup_session(ClientId) of
          undefined ->
              create_session(Client, State);
          Existing ->
              Reply = case resume_session(Existing, ClientPid) of
                          {ok, SessPid} -> {ok, SessPid, true};
                          {error, Reason} -> {error, Reason}
                      end,
              {reply, Reply, State}
      end;
  %% Transient session (CleanSess = true): destroy any registered session
  %% first, then always create a fresh one locally.
  handle_call({start_session, {true, ClientId, _ClientPid} = Client}, _From, State) ->
      case lookup_session(ClientId) of
          undefined ->
              create_session(Client, State);
          Existing ->
              case destroy_session(Existing) of
                  ok -> create_session(Client, State);
                  {error, Reason} -> {reply, {error, Reason}, State}
              end
      end;
  handle_call(Request, _From, State) ->
      ?UNEXPECTED_REQ(Request, State).
  %% No casts are part of this server's protocol; report anything received.
  handle_cast(Cast, State) ->
      ?UNEXPECTED_MSG(Cast, State).
  %% A monitored session process went down. If the monitor is ours, delete
  %% the global session record -- but only while it still points at the dead
  %% pid, because a newer session for the same client id may already have
  %% replaced it -- and forget the monitor. Cleanup is best-effort: an
  %% aborted transaction is logged (it used to be silently ignored, leaving
  %% a stale record with no trace) and the server keeps running.
  handle_info({'DOWN', MRef, process, DownPid, _Reason}, State) ->
      case dict:find(MRef, State#state.monitors) of
          {ok, ClientId} ->
              Cleanup = fun() ->
                            case mnesia:wread({session, ClientId}) of
                                [] -> ok;
                                [Sess = #mqtt_session{sess_pid = DownPid}] ->
                                    mnesia:delete_object(session, Sess, write);
                                [_Sess] -> ok
                            end
                        end,
              case mnesia:transaction(Cleanup) of
                  {atomic, ok} ->
                      ok;
                  {aborted, Error} ->
                      lager:error("SM(~s): Failed to remove session ~p: ~p",
                                  [ClientId, DownPid, Error])
              end,
              {noreply, erase_monitor(MRef, State)};
          error ->
              lager:error("MRef of session ~p not found", [DownPid]),
              {noreply, State}
      end;
  handle_info(Info, State) ->
      ?UNEXPECTED_INFO(Info, State).
  %% Leave the gproc pool this worker joined in init/1.
  terminate(_Why, #state{pool = P, id = I}) ->
      ?GPROC_POOL(leave, P, I).
  %% No state migration: the state is carried over unchanged.
  code_change(_FromVsn, S, _Extra) -> {ok, S}.
  153. %%--------------------------------------------------------------------
  154. %% Internal functions
  155. %%--------------------------------------------------------------------
  %% Create a session locally and build the gen_server reply. On success the
  %% new session process is monitored, so its record is cleaned up when it
  %% dies. The reply's third element is false: the session is new, not
  %% resumed.
  create_session({CleanSess, ClientId, ClientPid}, State) ->
      case create_session(CleanSess, ClientId, ClientPid) of
          {error, Error} ->
              {reply, {error, Error}, State};
          {ok, SessPid} ->
              Monitored = monitor_session(ClientId, SessPid, State),
              {reply, {ok, SessPid, false}, Monitored}
      end.
  %% Start a session process under emqttd_session_sup and register it in the
  %% global `session' table. Returns {ok, SessPid} or {error, Reason}.
  %%
  %% If registration fails, the freshly started process is destroyed again
  %% so it is not left running untracked (it would be in neither mnesia nor
  %% the monitor dict). Failure cases:
  %%   - conflict: a record for this client id already exists, possibly
  %%     created from another node; returns {error, mnesia_conflict}.
  %%   - any other transaction abort: returns {error, Reason} (this
  %%     previously crashed the manager with case_clause).
  create_session(CleanSess, ClientId, ClientPid) ->
      case emqttd_session_sup:start_session(CleanSess, ClientId, ClientPid) of
          {ok, SessPid} ->
              Session = #mqtt_session{client_id  = ClientId,
                                      sess_pid   = SessPid,
                                      persistent = not CleanSess},
              case insert_session(Session) of
                  {atomic, ok} ->
                      {ok, SessPid};
                  {aborted, {conflict, ConflictPid}} ->
                      lager:error("SM(~s): Conflict with ~p", [ClientId, ConflictPid]),
                      emqttd_session:destroy(SessPid, ClientId),
                      {error, mnesia_conflict};
                  {aborted, Reason} ->
                      lager:error("SM(~s): Failed to register session ~p: ~p",
                                  [ClientId, SessPid, Reason]),
                      emqttd_session:destroy(SessPid, ClientId),
                      {error, Reason}
              end;
          {error, Error} ->
              {error, Error}
      end.
  %% Write `Session' into the global `session' table inside a transaction.
  %% If a record for the same client id already exists, the transaction is
  %% aborted with {conflict, ExistingPid}. Returns the mnesia:transaction/1
  %% result unchanged.
  insert_session(#mqtt_session{client_id = ClientId} = Session) ->
      Insert = fun() ->
                   case mnesia:wread({session, ClientId}) of
                       [#mqtt_session{sess_pid = ExistingPid}] ->
                           mnesia:abort({conflict, ExistingPid});
                       [] ->
                           mnesia:write(session, Session, write)
                   end
               end,
      mnesia:transaction(Insert).
  %% Resume the registered session `Sess' for client process `ClientPid'.
  %% Returns {ok, SessPid} or {error, Reason}.
  %%
  %% Session on this node: resumed only if its process is still alive;
  %% otherwise {error, session_died} is returned and the record is left alone.
  resume_session(Sess = #mqtt_session{client_id = Id, sess_pid = Pid}, ClientPid)
      when node(Pid) =:= node() ->
      case is_process_alive(Pid) of
          false ->
              ?LOG(error, "Cannot resume ~p which seems already dead!", [Pid], Sess),
              {error, session_died};
          true ->
              emqttd_session:resume(Pid, Id, ClientPid),
              {ok, Pid}
      end;
  %% Session on another node: resumed via rpc:call/4. If that node is down,
  %% the stale record is removed and {error, session_nodedown} is returned.
  %% Other RPC failures are returned as {error, Reason}.
  resume_session(Sess = #mqtt_session{client_id = Id, sess_pid = Pid}, ClientPid) ->
      Owner = node(Pid),
      Result = rpc:call(Owner, emqttd_session, resume, [Pid, Id, ClientPid]),
      case Result of
          {badrpc, nodedown} ->
              ?LOG(error, "Session died for node '~s' down", [Owner], Sess),
              remove_session(Sess),
              {error, session_nodedown};
          {badrpc, Reason} ->
              ?LOG(error, "Failed to resume from node ~s for ~p", [Owner, Reason], Sess),
              {error, Reason};
          ok ->
              {ok, Pid}
      end.
  %% Destroy the registered session `Session' and remove its global record.
  %% Returns ok or {error, Reason}.
  %%
  %% Session on this node: ask the process to destroy itself (the return
  %% value is not checked), then remove the record.
  destroy_session(Session = #mqtt_session{client_id = ClientId, sess_pid = SessPid})
      when node(SessPid) =:= node() ->
      emqttd_session:destroy(SessPid, ClientId),
      remove_session(Session);
  %% Session on another node: destroy it via rpc:call/4. If that node is
  %% down, the session is treated as gone and its record removed. Any other
  %% RPC failure is returned and the record is kept.
  destroy_session(Session = #mqtt_session{client_id = ClientId, sess_pid = SessPid}) ->
      Node = node(SessPid),
      case rpc:call(Node, emqttd_session, destroy, [SessPid, ClientId]) of
          ok ->
              remove_session(Session);
          {badrpc, nodedown} ->
              ?LOG(error, "Node '~s' down", [Node], Session),
              remove_session(Session);
          {badrpc, Reason} ->
              %% Reason is an arbitrary term, so it must be formatted with ~p;
              %% ~s only accepts strings, binaries and atoms.
              ?LOG(error, "Failed to destroy ~p on remote node ~p for ~p",
                   [SessPid, Node, Reason], Session),
              {error, Reason}
      end.
  %% Delete exactly this `Session' object from the global `session' table in
  %% a transaction. Returns ok on success, {error, Reason} if it aborts.
  remove_session(Session) ->
      Delete = fun() -> mnesia:delete_object(session, Session, write) end,
      case mnesia:transaction(Delete) of
          {aborted, Reason} -> {error, Reason};
          {atomic, ok} -> ok
      end.
  %% Monitor session process `SessPid' and remember which client id the new
  %% monitor reference belongs to (used when its 'DOWN' message arrives).
  monitor_session(ClientId, SessPid, #state{monitors = Monitors} = State) ->
      Ref = erlang:monitor(process, SessPid),
      Updated = dict:store(Ref, ClientId, Monitors),
      State#state{monitors = Updated}.
  %% Forget the client-id mapping of monitor reference `MRef'.
  erase_monitor(MRef, #state{monitors = Monitors} = State) ->
      Remaining = dict:erase(MRef, Monitors),
      State#state{monitors = Remaining}.