|
|
@@ -33,6 +33,8 @@
|
|
|
-export_type([inflight/0, seqno/0]).
|
|
|
|
|
|
-include_lib("emqx/include/logger.hrl").
|
|
|
+-include_lib("emqx/include/emqx_mqtt.hrl").
|
|
|
+-include_lib("emqx_utils/include/emqx_message.hrl").
|
|
|
-include("emqx_persistent_session_ds.hrl").
|
|
|
|
|
|
-ifdef(TEST).
|
|
|
@@ -46,6 +48,8 @@
|
|
|
-define(COMP, 1).
|
|
|
|
|
|
-define(TRACK_FLAG(WHICH), (1 bsl WHICH)).
|
|
|
+-define(TRACK_FLAGS_ALL, ?TRACK_FLAG(?ACK) bor ?TRACK_FLAG(?COMP)).
|
|
|
+-define(TRACK_FLAGS_NONE, 0).
|
|
|
|
|
|
%%================================================================================
|
|
|
%% Type declarations
|
|
|
@@ -66,9 +70,10 @@
|
|
|
|
|
|
-opaque inflight() :: #inflight{}.
|
|
|
|
|
|
--type replies() :: reply() | [replies()].
|
|
|
--type reply() :: emqx_session:reply() | fun((emqx_types:packet_id()) -> emqx_session:replies()).
|
|
|
--type reply_fun() :: fun((seqno(), emqx_types:message()) -> replies()).
|
|
|
+-type message() :: emqx_types:message().
|
|
|
+-type replies() :: [emqx_session:reply()].
|
|
|
+
|
|
|
+-type preproc_fun() :: fun((message()) -> message() | [message()]).
|
|
|
|
|
|
%%================================================================================
|
|
|
%% API funcions
|
|
|
@@ -115,11 +120,11 @@ n_inflight(#inflight{offset_ranges = Ranges}) ->
|
|
|
Ranges
|
|
|
).
|
|
|
|
|
|
--spec replay(reply_fun(), inflight()) -> {emqx_session:replies(), inflight()}.
|
|
|
-replay(ReplyFun, Inflight0 = #inflight{offset_ranges = Ranges0}) ->
|
|
|
+-spec replay(preproc_fun(), inflight()) -> {emqx_session:replies(), inflight()}.
|
|
|
+replay(PreprocFunFun, Inflight0 = #inflight{offset_ranges = Ranges0, commits = Commits}) ->
|
|
|
{Ranges, Replies} = lists:mapfoldr(
|
|
|
fun(Range, Acc) ->
|
|
|
- replay_range(ReplyFun, Range, Acc)
|
|
|
+ replay_range(PreprocFunFun, Commits, Range, Acc)
|
|
|
end,
|
|
|
[],
|
|
|
Ranges0
|
|
|
@@ -165,9 +170,9 @@ commit_offset(
|
|
|
{false, Inflight0}
|
|
|
end.
|
|
|
|
|
|
--spec poll(reply_fun(), emqx_persistent_session_ds:id(), inflight(), pos_integer()) ->
|
|
|
+-spec poll(preproc_fun(), emqx_persistent_session_ds:id(), inflight(), pos_integer()) ->
|
|
|
{emqx_session:replies(), inflight()}.
|
|
|
-poll(ReplyFun, SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < ?EPOCH_SIZE ->
|
|
|
+poll(PreprocFun, SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize < ?EPOCH_SIZE ->
|
|
|
MinBatchSize = emqx_config:get([session_persistence, min_batch_size]),
|
|
|
FetchThreshold = min(MinBatchSize, ceil(WindowSize / 2)),
|
|
|
FreeSpace = WindowSize - n_inflight(Inflight0),
|
|
|
@@ -181,7 +186,7 @@ poll(ReplyFun, SessionId, Inflight0, WindowSize) when WindowSize > 0, WindowSize
|
|
|
true ->
|
|
|
%% TODO: Wrap this in `mria:async_dirty/2`?
|
|
|
Streams = shuffle(get_streams(SessionId)),
|
|
|
- fetch(ReplyFun, SessionId, Inflight0, Streams, FreeSpace, [])
|
|
|
+ fetch(PreprocFun, SessionId, Inflight0, Streams, FreeSpace, [])
|
|
|
end.
|
|
|
|
|
|
%% Which seqno this track is committed until.
|
|
|
@@ -248,22 +253,22 @@ get_ranges(SessionId) ->
|
|
|
),
|
|
|
mnesia:match_object(?SESSION_PUBRANGE_TAB, Pat, read).
|
|
|
|
|
|
-fetch(ReplyFun, SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 ->
|
|
|
+fetch(PreprocFun, SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 ->
|
|
|
#inflight{next_seqno = FirstSeqno, offset_ranges = Ranges} = Inflight0,
|
|
|
ItBegin = get_last_iterator(DSStream, Ranges),
|
|
|
{ok, ItEnd, Messages} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, ItBegin, N),
|
|
|
case Messages of
|
|
|
[] ->
|
|
|
- fetch(ReplyFun, SessionId, Inflight0, Streams, N, Acc);
|
|
|
+ fetch(PreprocFun, SessionId, Inflight0, Streams, N, Acc);
|
|
|
_ ->
|
|
|
%% We need to preserve the iterator pointing to the beginning of the
|
|
|
%% range, so that we can replay it if needed.
|
|
|
- {Publishes, {UntilSeqno, Tracks}} = publish(ReplyFun, FirstSeqno, Messages),
|
|
|
+ {Publishes, UntilSeqno} = publish_fetch(PreprocFun, FirstSeqno, Messages),
|
|
|
Size = range_size(FirstSeqno, UntilSeqno),
|
|
|
Range0 = #ds_pubrange{
|
|
|
id = {SessionId, FirstSeqno},
|
|
|
type = ?T_INFLIGHT,
|
|
|
- tracks = Tracks,
|
|
|
+ tracks = compute_pub_tracks(Publishes),
|
|
|
until = UntilSeqno,
|
|
|
stream = DSStream#ds_stream.ref,
|
|
|
iterator = ItBegin
|
|
|
@@ -277,7 +282,7 @@ fetch(ReplyFun, SessionId, Inflight0, [DSStream | Streams], N, Acc) when N > 0 -
|
|
|
next_seqno = UntilSeqno,
|
|
|
offset_ranges = Ranges ++ [Range]
|
|
|
},
|
|
|
- fetch(ReplyFun, SessionId, Inflight, Streams, N - Size, [Publishes | Acc])
|
|
|
+ fetch(PreprocFun, SessionId, Inflight, Streams, N - Size, [Publishes | Acc])
|
|
|
end;
|
|
|
fetch(_ReplyFun, _SessionId, Inflight, _Streams, _N, Acc) ->
|
|
|
Publishes = lists:append(lists:reverse(Acc)),
|
|
|
@@ -374,19 +379,20 @@ discard_tracks(#{ack := AckedUntil, comp := CompUntil}, Until, Tracks) ->
|
|
|
TAck bor TComp.
|
|
|
|
|
|
replay_range(
|
|
|
- ReplyFun,
|
|
|
+ PreprocFun,
|
|
|
+ Commits,
|
|
|
Range0 = #ds_pubrange{type = ?T_INFLIGHT, id = {_, First}, until = Until, iterator = It},
|
|
|
Acc
|
|
|
) ->
|
|
|
Size = range_size(First, Until),
|
|
|
{ok, ItNext, MessagesUnacked} = emqx_ds:next(?PERSISTENT_MESSAGE_DB, It, Size),
|
|
|
%% Asserting that range is consistent with the message storage state.
|
|
|
- {Replies, {Until, _TracksInitial}} = publish(ReplyFun, First, MessagesUnacked),
|
|
|
+ {Replies, Until} = publish_replay(PreprocFun, Commits, First, MessagesUnacked),
|
|
|
%% Again, we need to keep the iterator pointing past the end of the
|
|
|
%% range, so that we can pick up where we left off.
|
|
|
Range = keep_next_iterator(ItNext, Range0),
|
|
|
{Range, Replies ++ Acc};
|
|
|
-replay_range(_ReplyFun, Range0 = #ds_pubrange{type = ?T_CHECKPOINT}, Acc) ->
|
|
|
+replay_range(_PreprocFun, _Commits, Range0 = #ds_pubrange{type = ?T_CHECKPOINT}, Acc) ->
|
|
|
{Range0, Acc}.
|
|
|
|
|
|
validate_commit(
|
|
|
@@ -419,33 +425,88 @@ get_commit_next(rec, #inflight{next_seqno = NextSeqno}) ->
|
|
|
get_commit_next(comp, #inflight{commits = Commits}) ->
|
|
|
maps:get(rec, Commits).
|
|
|
|
|
|
-publish(ReplyFun, FirstSeqno, Messages) ->
|
|
|
- lists:mapfoldl(
|
|
|
- fun(Message, Acc = {Seqno, _Tracks}) ->
|
|
|
- Reply = ReplyFun(Seqno, Message),
|
|
|
- publish_reply(Reply, Acc)
|
|
|
+publish_fetch(PreprocFun, FirstSeqno, Messages) ->
|
|
|
+ flatmapfoldl(
|
|
|
+ fun(MessageIn, Acc) ->
|
|
|
+ Message = PreprocFun(MessageIn),
|
|
|
+ publish_fetch(Message, Acc)
|
|
|
+ end,
|
|
|
+ FirstSeqno,
|
|
|
+ Messages
|
|
|
+ ).
|
|
|
+
|
|
|
+publish_fetch(#message{qos = ?QOS_0} = Message, Seqno) ->
|
|
|
+ {{undefined, Message}, Seqno};
|
|
|
+publish_fetch(#message{} = Message, Seqno) ->
|
|
|
+ PacketId = seqno_to_packet_id(Seqno),
|
|
|
+ {{PacketId, Message}, next_seqno(Seqno)};
|
|
|
+publish_fetch(Messages, Seqno) ->
|
|
|
+ flatmapfoldl(fun publish_fetch/2, Seqno, Messages).
|
|
|
+
|
|
|
+publish_replay(PreprocFun, Commits, FirstSeqno, Messages) ->
|
|
|
+ #{ack := AckedUntil, comp := CompUntil, rec := RecUntil} = Commits,
|
|
|
+ flatmapfoldl(
|
|
|
+ fun(MessageIn, Acc) ->
|
|
|
+ Message = PreprocFun(MessageIn),
|
|
|
+ publish_replay(Message, AckedUntil, CompUntil, RecUntil, Acc)
|
|
|
end,
|
|
|
- {FirstSeqno, 0},
|
|
|
+ FirstSeqno,
|
|
|
Messages
|
|
|
).
|
|
|
|
|
|
-publish_reply(Replies = [_ | _], Acc) ->
|
|
|
- lists:mapfoldl(fun publish_reply/2, Acc, Replies);
|
|
|
-publish_reply(Reply, {Seqno, Tracks}) when is_function(Reply) ->
|
|
|
- Pub = Reply(seqno_to_packet_id(Seqno)),
|
|
|
- NextSeqno = next_seqno(Seqno),
|
|
|
- NextTracks = add_pub_track(Pub, Tracks),
|
|
|
- {Pub, {NextSeqno, NextTracks}};
|
|
|
-publish_reply(Reply, Acc) ->
|
|
|
- {Reply, Acc}.
|
|
|
-
|
|
|
-add_pub_track({PacketId, Message}, Tracks) when is_integer(PacketId) ->
|
|
|
- case emqx_message:qos(Message) of
|
|
|
- 1 -> ?TRACK_FLAG(?ACK) bor Tracks;
|
|
|
- 2 -> ?TRACK_FLAG(?COMP) bor Tracks;
|
|
|
- _ -> Tracks
|
|
|
+publish_replay(#message{qos = ?QOS_0}, _, _, _, Seqno) ->
|
|
|
+ %% QoS 0 (at most once) messages should not be replayed.
|
|
|
+ {[], Seqno};
|
|
|
+publish_replay(#message{qos = Qos} = Message, AckedUntil, CompUntil, RecUntil, Seqno) ->
|
|
|
+ case Qos of
|
|
|
+ ?QOS_1 when Seqno < AckedUntil ->
|
|
|
+ %% This message has already been acked, so we can skip it.
|
|
|
+ %% We still need to advance seqno, because previously we assigned this message
|
|
|
+ %% a unique Packet Id.
|
|
|
+ {[], next_seqno(Seqno)};
|
|
|
+ ?QOS_2 when Seqno < CompUntil ->
|
|
|
+ %% This message's flow has already been fully completed, so we can skip it.
|
|
|
+ %% We still need to advance seqno, because previously we assigned this message
|
|
|
+ %% a unique Packet Id.
|
|
|
+ {[], next_seqno(Seqno)};
|
|
|
+ ?QOS_2 when Seqno < RecUntil ->
|
|
|
+ %% This message's flow has been partially completed, we need to resend a PUBREL.
|
|
|
+ PacketId = seqno_to_packet_id(Seqno),
|
|
|
+ Pub = {pubrel, PacketId},
|
|
|
+ {Pub, next_seqno(Seqno)};
|
|
|
+ _ ->
|
|
|
+ %% This message flow hasn't been acked and/or received, we need to resend it.
|
|
|
+ PacketId = seqno_to_packet_id(Seqno),
|
|
|
+ Pub = {PacketId, emqx_message:set_flag(dup, true, Message)},
|
|
|
+ {Pub, next_seqno(Seqno)}
|
|
|
end;
|
|
|
-add_pub_track(_Pub, Tracks) ->
|
|
|
+publish_replay([], _, _, _, Seqno) ->
|
|
|
+ {[], Seqno};
|
|
|
+publish_replay(Messages, AckedUntil, CompUntil, RecUntil, Seqno) ->
|
|
|
+ flatmapfoldl(
|
|
|
+ fun(Message, Acc) ->
|
|
|
+ publish_replay(Message, AckedUntil, CompUntil, RecUntil, Acc)
|
|
|
+ end,
|
|
|
+ Seqno,
|
|
|
+ Messages
|
|
|
+ ).
|
|
|
+
|
|
|
+-spec compute_pub_tracks(replies()) -> non_neg_integer().
|
|
|
+compute_pub_tracks(Pubs) ->
|
|
|
+ compute_pub_tracks(Pubs, ?TRACK_FLAGS_NONE).
|
|
|
+
|
|
|
+compute_pub_tracks(_Pubs, Tracks = ?TRACK_FLAGS_ALL) ->
|
|
|
+ Tracks;
|
|
|
+compute_pub_tracks([Pub | Rest], Tracks) ->
|
|
|
+ Track =
|
|
|
+ case Pub of
|
|
|
+ {_PacketId, #message{qos = ?QOS_1}} -> ?TRACK_FLAG(?ACK);
|
|
|
+ {_PacketId, #message{qos = ?QOS_2}} -> ?TRACK_FLAG(?COMP);
|
|
|
+ {pubrel, _PacketId} -> ?TRACK_FLAG(?COMP);
|
|
|
+ _ -> ?TRACK_FLAGS_NONE
|
|
|
+ end,
|
|
|
+ compute_pub_tracks(Rest, Track bor Tracks);
|
|
|
+compute_pub_tracks([], Tracks) ->
|
|
|
Tracks.
|
|
|
|
|
|
keep_next_iterator(ItNext, Range = #ds_pubrange{iterator = ItFirst, misc = Misc}) ->
|
|
|
@@ -550,6 +611,19 @@ shuffle(L0) ->
|
|
|
{_, L} = lists:unzip(L2),
|
|
|
L.
|
|
|
|
|
|
+-spec flatmapfoldl(fun((X, Acc) -> {Y | [Y], Acc}), Acc, [X]) -> {[Y], Acc}.
|
|
|
+flatmapfoldl(_Fun, Acc, []) ->
|
|
|
+ {[], Acc};
|
|
|
+flatmapfoldl(Fun, Acc, [X | Xs]) ->
|
|
|
+ {Ys, NAcc} = Fun(X, Acc),
|
|
|
+ {Zs, FAcc} = flatmapfoldl(Fun, NAcc, Xs),
|
|
|
+ case is_list(Ys) of
|
|
|
+ true ->
|
|
|
+ {Ys ++ Zs, FAcc};
|
|
|
+ _ ->
|
|
|
+ {[Ys | Zs], FAcc}
|
|
|
+ end.
|
|
|
+
|
|
|
ro_transaction(Fun) ->
|
|
|
{atomic, Res} = mria:ro_transaction(?DS_MRIA_SHARD, Fun),
|
|
|
Res.
|