Przeglądaj źródła

fix(ds): ensure store batch is idempotent wrt generations

Andrew Mayorov 1 rok temu
rodzic
commit
2cd357a5bd

+ 30 - 18
apps/emqx_durable_storage/src/emqx_ds_storage_layer.erl

@@ -251,23 +251,13 @@ drop_shard(Shard) ->
     emqx_ds:message_store_opts()
     emqx_ds:message_store_opts()
 ) ->
 ) ->
     emqx_ds:store_batch_result().
     emqx_ds:store_batch_result().
-store_batch(Shard, Messages0, Options) ->
-    %% We always store messages in the current generation:
-    GenId = generation_current(Shard),
-    #{module := Mod, data := GenData, since := Since} = generation_get(Shard, GenId),
-    case Messages0 of
-        [{Time, _Msg} | Rest] when Time < Since ->
-            %% FIXME: log / feedback
-            Messages = skip_outdated_messages(Since, Rest);
-        _ ->
-            Messages = Messages0
-    end,
-    Mod:store_batch(Shard, GenData, Messages, Options).
-
-skip_outdated_messages(Since, [{Time, _Msg} | Rest]) when Time < Since ->
-    skip_outdated_messages(Since, Rest);
-skip_outdated_messages(_Since, Messages) ->
-    Messages.
+store_batch(Shard, Messages = [{Time, _Msg} | _], Options) ->
+    %% NOTE
+    %% We assume that batches do not span generations. Callers should enforce this.
+    #{module := Mod, data := GenData} = generation_at(Shard, Time),
+    Mod:store_batch(Shard, GenData, Messages, Options);
+store_batch(_Shard, [], _Options) ->
+    ok.
 
 
 -spec get_streams(shard_id(), emqx_ds:topic_filter(), emqx_ds:time()) ->
 -spec get_streams(shard_id(), emqx_ds:topic_filter(), emqx_ds:time()) ->
     [{integer(), stream()}].
     [{integer(), stream()}].
@@ -715,7 +705,7 @@ create_new_shard_schema(ShardId, DB, CFRefs, Prototype) ->
     {gen_id(), shard_schema(), cf_refs()}.
     {gen_id(), shard_schema(), cf_refs()}.
 new_generation(ShardId, DB, Schema0, Since) ->
 new_generation(ShardId, DB, Schema0, Since) ->
     #{current_generation := PrevGenId, prototype := {Mod, ModConf}} = Schema0,
     #{current_generation := PrevGenId, prototype := {Mod, ModConf}} = Schema0,
-    GenId = PrevGenId + 1,
+    GenId = next_generation_id(PrevGenId),
     {GenData, NewCFRefs} = Mod:create(ShardId, DB, GenId, ModConf),
     {GenData, NewCFRefs} = Mod:create(ShardId, DB, GenId, ModConf),
     GenSchema = #{
     GenSchema = #{
         module => Mod,
         module => Mod,
@@ -731,6 +721,14 @@ new_generation(ShardId, DB, Schema0, Since) ->
     },
     },
     {GenId, Schema, NewCFRefs}.
     {GenId, Schema, NewCFRefs}.
 
 
+-spec next_generation_id(gen_id()) -> gen_id().
+next_generation_id(GenId) ->
+    GenId + 1.
+
+-spec prev_generation_id(gen_id()) -> gen_id().
+prev_generation_id(GenId) when GenId > 0 ->
+    GenId - 1.
+
 %% @doc Commit current state of the server to both rocksdb and the persistent term
 %% @doc Commit current state of the server to both rocksdb and the persistent term
 -spec commit_metadata(server_state()) -> ok.
 -spec commit_metadata(server_state()) -> ok.
 commit_metadata(#s{shard_id = ShardId, schema = Schema, shard = Runtime, db = DB}) ->
 commit_metadata(#s{shard_id = ShardId, schema = Schema, shard = Runtime, db = DB}) ->
@@ -854,6 +852,20 @@ generations_since(Shard, Since) ->
         Schema
         Schema
     ).
     ).
 
 
+-spec generation_at(shard_id(), emqx_ds:time()) -> generation().
+generation_at(Shard, Time) ->
+    Schema = #{current_generation := Current} = get_schema_runtime(Shard),
+    generation_at(Time, Current, Schema).
+
+generation_at(Time, GenId, Schema) ->
+    #{?GEN_KEY(GenId) := Gen} = Schema,
+    case Gen of
+        #{since := Since} when Time < Since andalso GenId > 0 ->
+            generation_at(Time, prev_generation_id(GenId), Schema);
+        _ ->
+            Gen
+    end.
+
 -define(PERSISTENT_TERM(SHARD), {emqx_ds_storage_layer, SHARD}).
 -define(PERSISTENT_TERM(SHARD), {emqx_ds_storage_layer, SHARD}).
 
 
 -spec get_schema_runtime(shard_id()) -> shard().
 -spec get_schema_runtime(shard_id()) -> shard().

+ 60 - 38
apps/emqx_durable_storage/test/emqx_ds_storage_snapshot_SUITE.erl

@@ -13,7 +13,7 @@
 %% See the License for the specific language governing permissions and
 %% See the License for the specific language governing permissions and
 %% limitations under the License.
 %% limitations under the License.
 %%--------------------------------------------------------------------
 %%--------------------------------------------------------------------
--module(emqx_ds_storage_snapshot_SUITE).
+-module(emqx_ds_storage_SUITE).
 
 
 -compile(export_all).
 -compile(export_all).
 -compile(nowarn_export_all).
 -compile(nowarn_export_all).
@@ -27,18 +27,37 @@ opts() ->
 
 
 %%
 %%
 
 
+t_idempotent_store_batch(_Config) ->
+    Shard = {?FUNCTION_NAME, _ShardId = <<"42">>},
+    {ok, Pid} = emqx_ds_storage_layer:start_link(Shard, opts()),
+    %% Push some messages to the shard.
+    Msgs1 = [gen_message(N) || N <- lists:seq(10, 20)],
+    GenTs = 30,
+    Msgs2 = [gen_message(N) || N <- lists:seq(40, 50)],
+    ?assertEqual(ok, emqx_ds_storage_layer:store_batch(Shard, batch(Msgs1), #{})),
+    %% Add new generation and push the same batch + some more.
+    ?assertEqual(ok, emqx_ds_storage_layer:add_generation(Shard, GenTs)),
+    ?assertEqual(ok, emqx_ds_storage_layer:store_batch(Shard, batch(Msgs1), #{})),
+    ?assertEqual(ok, emqx_ds_storage_layer:store_batch(Shard, batch(Msgs2), #{})),
+    %% First batch should have been handled idempotently.
+    ?assertEqual(
+        Msgs1 ++ Msgs2,
+        lists:keysort(#message.timestamp, consume(Shard, ['#']))
+    ),
+    ok = stop_shard(Pid).
+
 t_snapshot_take_restore(_Config) ->
 t_snapshot_take_restore(_Config) ->
     Shard = {?FUNCTION_NAME, _ShardId = <<"42">>},
     Shard = {?FUNCTION_NAME, _ShardId = <<"42">>},
     {ok, Pid} = emqx_ds_storage_layer:start_link(Shard, opts()),
     {ok, Pid} = emqx_ds_storage_layer:start_link(Shard, opts()),
 
 
     %% Push some messages to the shard.
     %% Push some messages to the shard.
     Msgs1 = [gen_message(N) || N <- lists:seq(1000, 2000)],
     Msgs1 = [gen_message(N) || N <- lists:seq(1000, 2000)],
-    ?assertEqual(ok, emqx_ds_storage_layer:store_batch(Shard, mk_batch(Msgs1), #{})),
+    ?assertEqual(ok, emqx_ds_storage_layer:store_batch(Shard, batch(Msgs1), #{})),
 
 
     %% Add new generation and push some more.
     %% Add new generation and push some more.
     ?assertEqual(ok, emqx_ds_storage_layer:add_generation(Shard, 3000)),
     ?assertEqual(ok, emqx_ds_storage_layer:add_generation(Shard, 3000)),
     Msgs2 = [gen_message(N) || N <- lists:seq(4000, 5000)],
     Msgs2 = [gen_message(N) || N <- lists:seq(4000, 5000)],
-    ?assertEqual(ok, emqx_ds_storage_layer:store_batch(Shard, mk_batch(Msgs2), #{})),
+    ?assertEqual(ok, emqx_ds_storage_layer:store_batch(Shard, batch(Msgs2), #{})),
     ?assertEqual(ok, emqx_ds_storage_layer:add_generation(Shard, 6000)),
     ?assertEqual(ok, emqx_ds_storage_layer:add_generation(Shard, 6000)),
 
 
     %% Take a snapshot of the shard.
     %% Take a snapshot of the shard.
@@ -46,11 +65,10 @@ t_snapshot_take_restore(_Config) ->
 
 
     %% Push even more messages to the shard AFTER taking the snapshot.
     %% Push even more messages to the shard AFTER taking the snapshot.
     Msgs3 = [gen_message(N) || N <- lists:seq(7000, 8000)],
     Msgs3 = [gen_message(N) || N <- lists:seq(7000, 8000)],
-    ?assertEqual(ok, emqx_ds_storage_layer:store_batch(Shard, mk_batch(Msgs3), #{})),
+    ?assertEqual(ok, emqx_ds_storage_layer:store_batch(Shard, batch(Msgs3), #{})),
 
 
     %% Destroy the shard.
     %% Destroy the shard.
-    _ = unlink(Pid),
-    ok = proc_lib:stop(Pid, shutdown, infinity),
+    ok = stop_shard(Pid),
     ok = emqx_ds_storage_layer:drop_shard(Shard),
     ok = emqx_ds_storage_layer:drop_shard(Shard),
 
 
     %% Restore the shard from the snapshot.
     %% Restore the shard from the snapshot.
@@ -64,12 +82,41 @@ t_snapshot_take_restore(_Config) ->
         lists:keysort(#message.timestamp, consume(Shard, ['#']))
         lists:keysort(#message.timestamp, consume(Shard, ['#']))
     ).
     ).
 
 
-mk_batch(Msgs) ->
-    [{emqx_message:timestamp(Msg, microsecond), Msg} || Msg <- Msgs].
+transfer_snapshot(Reader, Writer) ->
+    ChunkSize = rand:uniform(1024),
+    ReadResult = emqx_ds_storage_snapshot:read_chunk(Reader, ChunkSize),
+    ?assertMatch({RStatus, _, _} when RStatus == next; RStatus == last, ReadResult),
+    {RStatus, Chunk, NReader} = ReadResult,
+    Data = iolist_to_binary(Chunk),
+    {WStatus, NWriter} = emqx_ds_storage_snapshot:write_chunk(Writer, Data),
+    %% Verify idempotency.
+    ?assertMatch(
+        {WStatus, NWriter},
+        emqx_ds_storage_snapshot:write_chunk(NWriter, Data)
+    ),
+    %% Verify convergence.
+    ?assertEqual(
+        RStatus,
+        WStatus,
+        #{reader => NReader, writer => NWriter}
+    ),
+    case WStatus of
+        last ->
+            ?assertEqual(ok, emqx_ds_storage_snapshot:release_reader(NReader)),
+            ?assertEqual(ok, emqx_ds_storage_snapshot:release_writer(NWriter)),
+            ok;
+        next ->
+            transfer_snapshot(NReader, NWriter)
+    end.
+
+%%
+
+batch(Msgs) ->
+    [{emqx_message:timestamp(Msg), Msg} || Msg <- Msgs].
 
 
 gen_message(N) ->
 gen_message(N) ->
     Topic = emqx_topic:join([<<"foo">>, <<"bar">>, integer_to_binary(N)]),
     Topic = emqx_topic:join([<<"foo">>, <<"bar">>, integer_to_binary(N)]),
-    message(Topic, integer_to_binary(N), N * 100).
+    message(Topic, crypto:strong_rand_bytes(16), N).
 
 
 message(Topic, Payload, PublishedAt) ->
 message(Topic, Payload, PublishedAt) ->
     #message{
     #message{
@@ -80,35 +127,6 @@ message(Topic, Payload, PublishedAt) ->
         id = emqx_guid:gen()
         id = emqx_guid:gen()
     }.
     }.
 
 
-transfer_snapshot(Reader, Writer) ->
-    ChunkSize = rand:uniform(1024),
-    case emqx_ds_storage_snapshot:read_chunk(Reader, ChunkSize) of
-        {RStatus, Chunk, NReader} ->
-            Data = iolist_to_binary(Chunk),
-            {WStatus, NWriter} = emqx_ds_storage_snapshot:write_chunk(Writer, Data),
-            %% Verify idempotency.
-            ?assertEqual(
-                {WStatus, NWriter},
-                emqx_ds_storage_snapshot:write_chunk(Writer, Data)
-            ),
-            %% Verify convergence.
-            ?assertEqual(
-                RStatus,
-                WStatus,
-                #{reader => NReader, writer => NWriter}
-            ),
-            case WStatus of
-                last ->
-                    ?assertEqual(ok, emqx_ds_storage_snapshot:release_reader(NReader)),
-                    ?assertEqual(ok, emqx_ds_storage_snapshot:release_writer(NWriter)),
-                    ok;
-                next ->
-                    transfer_snapshot(NReader, NWriter)
-            end;
-        {error, Reason} ->
-            {error, Reason, Reader}
-    end.
-
 consume(Shard, TopicFilter) ->
 consume(Shard, TopicFilter) ->
     consume(Shard, TopicFilter, 0).
     consume(Shard, TopicFilter, 0).
 
 
@@ -132,6 +150,10 @@ consume_stream(Shard, It) ->
             []
             []
     end.
     end.
 
 
+stop_shard(Pid) ->
+    _ = unlink(Pid),
+    proc_lib:stop(Pid, shutdown, infinity).
+
 %%
 %%
 
 
 all() -> emqx_common_test_helpers:all(?MODULE).
 all() -> emqx_common_test_helpers:all(?MODULE).