Просмотр исходного кода

feat(util-stream): add few more stream evaluators including folds

Andrew Mayorov 1 год назад
Родитель
Сommit
e3ac700ac8

+ 53 - 11
apps/emqx_utils/src/emqx_utils_stream.erl

@@ -39,7 +39,10 @@
     next/1,
     consume/1,
     consume/2,
-    foreach/2
+    foreach/2,
+    fold/3,
+    fold/4,
+    sweep/2
 ]).
 
 %% Streams from ETS tables
@@ -108,6 +111,7 @@ map(F, S) ->
     end.
 
 %% @doc Make a stream by filtering the underlying stream with a predicate function.
+-spec filter(fun((X) -> boolean()), stream(X)) -> stream(X).
 filter(F, S) ->
     FilterNext = fun FilterNext(St) ->
         case next(St) of
@@ -124,16 +128,6 @@ filter(F, S) ->
     end,
     fun() -> FilterNext(S) end.
 
-%% @doc Consumes the stream and applies the given function to each element.
-foreach(F, S) ->
-    case next(S) of
-        [X | Rest] ->
-            F(X),
-            foreach(F, Rest);
-        [] ->
-            ok
-    end.
-
 %% @doc Drops N first elements from the stream
 -spec drop(non_neg_integer(), stream(T)) -> stream(T).
 drop(N, S) ->
@@ -297,6 +291,54 @@ consume(N, S, Acc) ->
             lists:reverse(Acc)
     end.
 
+%% @doc Consumes the stream and applies the given function to each element.
+-spec foreach(fun((X) -> _), stream(X)) -> ok.
+foreach(F, S) ->
+    case next(S) of
+        [X | Rest] ->
+            F(X),
+            foreach(F, Rest);
+        [] ->
+            ok
+    end.
+
+%% @doc Folds the whole stream, accumulating the result of given function applied
+%% to each element.
+-spec fold(fun((X, Acc) -> Acc), Acc, stream(X)) -> Acc.
+fold(F, Acc, S) ->
+    case next(S) of
+        [X | Rest] ->
+            fold(F, F(X, Acc), Rest);
+        [] ->
+            Acc
+    end.
+
+%% @doc Folds the first N element of the stream, accumulating the result of given
+%% function applied to each element. If there's less than N elements in the given
+%% stream, returns `[]` (a.k.a. empty stream) along with the accumulated value.
+-spec fold(fun((X, Acc) -> Acc), Acc, non_neg_integer(), stream(X)) -> {Acc, stream(X)}.
+fold(_, Acc, 0, S) ->
+    {Acc, S};
+fold(F, Acc, N, S) when N > 0 ->
+    case next(S) of
+        [X | Rest] ->
+            fold(F, F(X, Acc), N - 1, Rest);
+        [] ->
+            {Acc, []}
+    end.
+
+%% @doc Same as `consume/2` but discard the consumed values.
+-spec sweep(non_neg_integer(), stream(X)) -> stream(X).
+sweep(0, S) ->
+    S;
+sweep(N, S) when N > 0 ->
+    case next(S) of
+        [_ | Rest] ->
+            sweep(N - 1, Rest);
+        [] ->
+            []
+    end.
+
 %%
 
 -type select_result(Record, Cont) ::

+ 19 - 0
apps/emqx_utils/test/emqx_utils_stream_tests.erl

@@ -114,6 +114,25 @@ foreach_test() ->
         emqx_utils_stream:consume(emqx_utils_stream:mqueue(100))
     ).
 
+fold_test() ->
+    S = emqx_utils_stream:drop(2, emqx_utils_stream:list([1, 2, 3, 4, 5])),
+    ?assertEqual(
+        3 * 4 * 5,
+        emqx_utils_stream:fold(fun(X, P) -> P * X end, 1, S)
+    ).
+
+fold_n_test() ->
+    S = emqx_utils_stream:repeat(
+        emqx_utils_stream:map(
+            fun(X) -> X * 2 end,
+            emqx_utils_stream:list([1, 2, 3])
+        )
+    ),
+    ?assertMatch(
+        {2 + 4 + 6 + 2 + 4 + 6 + 2, _SRest},
+        emqx_utils_stream:fold(fun(X, Sum) -> Sum + X end, 0, _N = 7, S)
+    ).
+
 chainmap_test() ->
     S = emqx_utils_stream:chainmap(
         fun(N) ->