Просмотр исходного кода

refactor: move the env interpolation function to emqx_schema

also added test cases
Zaiming (Stone) Shi 2 лет назад
Родитель
Сommit
7c5a9e0e20
3 измененных файлов с 153 добавлено и 57 удалено
  1. 59 13
      apps/emqx/src/emqx_schema.erl
  2. 16 44
      apps/emqx/src/emqx_tls_lib.erl
  3. 78 0
      apps/emqx/test/emqx_schema_tests.erl

+ 59 - 13
apps/emqx/src/emqx_schema.erl

@@ -66,7 +66,8 @@
     user_lookup_fun_tr/2,
     validate_alarm_actions/1,
     non_empty_string/1,
-    validations/0
+    validations/0,
+    naive_env_interpolation/1
 ]).
 
 -export([qos/0]).
@@ -825,7 +826,7 @@ fields("crl_cache") ->
     %% same URL.  If they had diverging timeout options, it would be
     %% confusing.
     [
-        {"refresh_interval",
+        {refresh_interval,
             sc(
                 duration(),
                 #{
@@ -833,7 +834,7 @@ fields("crl_cache") ->
                     desc => ?DESC("crl_cache_refresh_interval")
                 }
             )},
-        {"http_timeout",
+        {http_timeout,
             sc(
                 duration(),
                 #{
@@ -841,7 +842,7 @@ fields("crl_cache") ->
                     desc => ?DESC("crl_cache_refresh_http_timeout")
                 }
             )},
-        {"capacity",
+        {capacity,
             sc(
                 pos_integer(),
                 #{
@@ -1358,7 +1359,7 @@ fields("ssl_client_opts") ->
     client_ssl_opts_schema(#{});
 fields("ocsp") ->
     [
-        {"enable_ocsp_stapling",
+        {enable_ocsp_stapling,
             sc(
                 boolean(),
                 #{
@@ -1366,7 +1367,7 @@ fields("ocsp") ->
                     desc => ?DESC("server_ssl_opts_schema_enable_ocsp_stapling")
                 }
             )},
-        {"responder_url",
+        {responder_url,
             sc(
                 url(),
                 #{
@@ -1374,7 +1375,7 @@ fields("ocsp") ->
                     desc => ?DESC("server_ssl_opts_schema_ocsp_responder_url")
                 }
             )},
-        {"issuer_pem",
+        {issuer_pem,
             sc(
                 binary(),
                 #{
@@ -1382,7 +1383,7 @@ fields("ocsp") ->
                     desc => ?DESC("server_ssl_opts_schema_ocsp_issuer_pem")
                 }
             )},
-        {"refresh_interval",
+        {refresh_interval,
             sc(
                 duration(),
                 #{
@@ -1390,7 +1391,7 @@ fields("ocsp") ->
                     desc => ?DESC("server_ssl_opts_schema_ocsp_refresh_interval")
                 }
             )},
-        {"refresh_http_timeout",
+        {refresh_http_timeout,
             sc(
                 duration(),
                 #{
@@ -2317,12 +2318,12 @@ server_ssl_opts_schema(Defaults, IsRanchListener) ->
             Field
          || not IsRanchListener,
             Field <- [
-                {"gc_after_handshake",
+                {gc_after_handshake,
                     sc(boolean(), #{
                         default => false,
                         desc => ?DESC(server_ssl_opts_schema_gc_after_handshake)
                     })},
-                {"ocsp",
+                {ocsp,
                     sc(
                         ref("ocsp"),
                         #{
@@ -2330,7 +2331,7 @@ server_ssl_opts_schema(Defaults, IsRanchListener) ->
                             validator => fun ocsp_inner_validator/1
                         }
                     )},
-                {"enable_crl_check",
+                {enable_crl_check,
                     sc(
                         boolean(),
                         #{
@@ -3106,7 +3107,7 @@ default_listener(ws) ->
             }
     };
 default_listener(SSLListener) ->
-    %% The env variable is resolved in emqx_tls_lib
+    %% The env variable is resolved in emqx_tls_lib by calling naive_env_interpolate
     CertFile = fun(Name) ->
         iolist_to_binary("${EMQX_ETC_DIR}/" ++ filename:join(["certs", Name]))
     end,
@@ -3136,3 +3137,48 @@ default_listener(SSLListener) ->
                     }
             }
     end.
+
+%% @doc This function helps to perform a naive string interpolation which
+%% only looks at the first segment of the string and tries to replace it.
+%% For example
+%%  "$MY_FILE_PATH"
+%%  "${MY_FILE_PATH}"
+%%  "$ENV_VARIABLE/sub/path"
+%%  "${ENV_VARIABLE}/sub/path"
+%%  "${ENV_VARIABLE}\sub\path" # windows
+%% This function returns undefined if the input is undefined
+%% otherwise always return string.
+naive_env_interpolation(undefined) ->
+    undefined;
+naive_env_interpolation(Bin) when is_binary(Bin) ->
+    naive_env_interpolation(unicode:characters_to_list(Bin, utf8));
+naive_env_interpolation("$" ++ Maybe = Original) ->
+    {Env, Tail} = split_path(Maybe),
+    case resolve_env(Env) of
+        {ok, Path} ->
+            filename:join([Path, Tail]);
+        error ->
+            Original
+    end;
+naive_env_interpolation(Other) ->
+    Other.
+
+split_path(Path) ->
+    split_path(Path, []).
+
+split_path([], Acc) ->
+    {lists:reverse(Acc), []};
+split_path([Char | Rest], Acc) when Char =:= $/ orelse Char =:= $\\ ->
+    {lists:reverse(Acc), string:trim(Rest, leading, "/\\")};
+split_path([Char | Rest], Acc) ->
+    split_path(Rest, [Char | Acc]).
+
+resolve_env(Name0) ->
+    Name = string:trim(Name0, both, "{}"),
+    Value = os:getenv(Name),
+    case Value =/= false andalso Value =/= "" of
+        true ->
+            {ok, Value};
+        false ->
+            error
+    end.

+ 16 - 44
apps/emqx/src/emqx_tls_lib.erl

@@ -347,8 +347,7 @@ delete_ssl_files(Dir, NewOpts0, OldOpts0) ->
 delete_old_file(New, Old) when New =:= Old -> ok;
 delete_old_file(_New, _Old = undefined) ->
     ok;
-delete_old_file(_New, Old0) ->
-    Old = resolve_cert_path(Old0),
+delete_old_file(_New, Old) ->
     case is_generated_file(Old) andalso filelib:is_regular(Old) andalso file:delete(Old) of
         ok ->
             ok;
@@ -356,7 +355,7 @@ delete_old_file(_New, Old0) ->
         false ->
             ok;
         {error, Reason} ->
-            ?SLOG(error, #{msg => "failed_to_delete_ssl_file", file_path => Old0, reason => Reason})
+            ?SLOG(error, #{msg => "failed_to_delete_ssl_file", file_path => Old, reason => Reason})
     end.
 
 ensure_ssl_file(_Dir, _KeyPath, SSL, undefined, _Opts) ->
@@ -415,8 +414,7 @@ is_pem(MaybePem) ->
 %% To make it simple, the file is always overwritten.
 %% Also a potentially half-written PEM file (e.g. due to power outage)
 %% can be corrected with an overwrite.
-save_pem_file(Dir0, KeyPath, Pem, DryRun) ->
-    Dir = resolve_cert_path(Dir0),
+save_pem_file(Dir, KeyPath, Pem, DryRun) ->
     Path = pem_file_name(Dir, KeyPath, Pem),
     case filelib:ensure_dir(Path) of
         ok when DryRun ->
@@ -475,7 +473,7 @@ hex_str(Bin) ->
 
 %% @doc Returns 'true' when the file is a valid pem, otherwise {error, Reason}.
 is_valid_pem_file(Path0) ->
-    Path = resolve_cert_path(Path0),
+    Path = resolve_cert_path_for_read(Path0),
     case file:read_file(Path) of
         {ok, Pem} -> is_pem(Pem) orelse {error, not_pem};
         {error, Reason} -> {error, Reason}
@@ -516,11 +514,12 @@ do_drop_invalid_certs([KeyPath | KeyPaths], SSL) ->
 to_server_opts(Type, Opts) ->
     Versions = integral_versions(Type, maps:get(versions, Opts, undefined)),
     Ciphers = integral_ciphers(Versions, maps:get(ciphers, Opts, undefined)),
+    Path = fun(Key) -> resolve_cert_path_for_read_strict(maps:get(Key, Opts, undefined)) end,
     filter(
         maps:to_list(Opts#{
-            keyfile => resolve_cert_path_strict(maps:get(keyfile, Opts, undefined)),
-            certfile => resolve_cert_path_strict(maps:get(certfile, Opts, undefined)),
-            cacertfile => resolve_cert_path_strict(maps:get(cacertfile, Opts, undefined)),
+            keyfile => Path(keyfile),
+            certfile => Path(certfile),
+            cacertfile => Path(cacertfile),
             ciphers => Ciphers,
             versions => Versions
         })
@@ -538,11 +537,12 @@ to_client_opts(Opts) ->
 to_client_opts(Type, Opts) ->
     GetD = fun(Key, Default) -> fuzzy_map_get(Key, Opts, Default) end,
     Get = fun(Key) -> GetD(Key, undefined) end,
+    Path = fun(Key) -> resolve_cert_path_for_read_strict(Get(Key)) end,
     case GetD(enable, false) of
         true ->
-            KeyFile = resolve_cert_path_strict(Get(keyfile)),
-            CertFile = resolve_cert_path_strict(Get(certfile)),
-            CAFile = resolve_cert_path_strict(Get(cacertfile)),
+            KeyFile = Path(keyfile),
+            CertFile = Path(certfile),
+            CAFile = Path(cacertfile),
             Verify = GetD(verify, verify_none),
             SNI = ensure_sni(Get(server_name_indication)),
             Versions = integral_versions(Type, Get(versions)),
@@ -564,8 +564,8 @@ to_client_opts(Type, Opts) ->
             []
     end.
 
-resolve_cert_path_strict(Path) ->
-    case resolve_cert_path(Path) of
+resolve_cert_path_for_read_strict(Path) ->
+    case resolve_cert_path_for_read(Path) of
         undefined ->
             undefined;
         ResolvedPath ->
@@ -586,36 +586,8 @@ resolve_cert_path_strict(Path) ->
             end
     end.
 
-resolve_cert_path(undefined) ->
-    undefined;
-resolve_cert_path(Path) ->
-    case ensure_str(Path) of
-        "$" ++ Maybe ->
-            naive_env_resolver(Maybe);
-        Other ->
-            Other
-    end.
-
-%% resolves a file path like "ENV_VARIABLE/sub/path" or "{ENV_VARIABLE}/sub/path"
-%% in windows, it could be "ENV_VARIABLE/sub\path" or "{ENV_VARIABLE}/sub\path"
-naive_env_resolver(Maybe) ->
-    case string:split(Maybe, "/") of
-        [_] ->
-            Maybe;
-        [Env, SubPath] ->
-            case os:getenv(trim_env_name(Env)) of
-                false ->
-                    SubPath;
-                "" ->
-                    SubPath;
-                EnvValue ->
-                    filename:join(EnvValue, SubPath)
-            end
-    end.
-
-%% delete the first and last curly braces
-trim_env_name(Env) ->
-    string:trim(Env, both, "{}").
+resolve_cert_path_for_read(Path) ->
+    emqx_schema:naive_env_interpolation(Path).
 
 filter([]) -> [];
 filter([{_, undefined} | T]) -> filter(T);

+ 78 - 0
apps/emqx/test/emqx_schema_tests.erl

@@ -513,3 +513,81 @@ url_type_test_() ->
             typerefl:from_string(emqx_schema:url(), <<"">>)
         )
     ].
+
+env_test_() ->
+    Do = fun emqx_schema:naive_env_interpolation/1,
+    [
+        {"undefined", fun() -> ?assertEqual(undefined, Do(undefined)) end},
+        {"full env abs path",
+            with_env_fn(
+                "MY_FILE",
+                "/path/to/my/file",
+                fun() -> ?assertEqual("/path/to/my/file", Do("$MY_FILE")) end
+            )},
+        {"full env relative path",
+            with_env_fn(
+                "MY_FILE",
+                "path/to/my/file",
+                fun() -> ?assertEqual("path/to/my/file", Do("${MY_FILE}")) end
+            )},
+        %% we can not test windows style file join though
+        {"windows style",
+            with_env_fn(
+                "MY_FILE",
+                "path\\to\\my\\file",
+                fun() -> ?assertEqual("path\\to\\my\\file", Do("$MY_FILE")) end
+            )},
+        {"dir no {}",
+            with_env_fn(
+                "MY_DIR",
+                "/mydir",
+                fun() -> ?assertEqual("/mydir/foobar", Do(<<"$MY_DIR/foobar">>)) end
+            )},
+        {"dir with {}",
+            with_env_fn(
+                "MY_DIR",
+                "/mydir",
+                fun() -> ?assertEqual("/mydir/foobar", Do(<<"${MY_DIR}/foobar">>)) end
+            )},
+        %% a trailing / should not cause the sub path to become absolute
+        {"env dir with trailing /",
+            with_env_fn(
+                "MY_DIR",
+                "/mydir//",
+                fun() -> ?assertEqual("/mydir/foobar", Do(<<"${MY_DIR}/foobar">>)) end
+            )},
+        {"string dir with doulbe /",
+            with_env_fn(
+                "MY_DIR",
+                "/mydir/",
+                fun() -> ?assertEqual("/mydir/foobar", Do(<<"${MY_DIR}//foobar">>)) end
+            )},
+        {"env not found",
+            with_env_fn(
+                "MY_DIR",
+                "/mydir/",
+                fun() -> ?assertEqual("${MY_DIR2}//foobar", Do(<<"${MY_DIR2}//foobar">>)) end
+            )}
+    ].
+
+with_env_fn(Name, Value, F) ->
+    fun() ->
+        with_envs(F, [{Name, Value}])
+    end.
+
+with_envs(Fun, Envs) ->
+    with_envs(Fun, [], Envs).
+
+with_envs(Fun, Args, [{_Name, _Value} | _] = Envs) ->
+    set_envs(Envs),
+    try
+        apply(Fun, Args)
+    after
+        unset_envs(Envs)
+    end.
+
+set_envs([{_Name, _Value} | _] = Envs) ->
+    lists:map(fun({Name, Value}) -> os:putenv(Name, Value) end, Envs).
+
+unset_envs([{_Name, _Value} | _] = Envs) ->
+    lists:map(fun({Name, _}) -> os:unsetenv(Name) end, Envs).