%% emqx_auth_pgsql_cli.erl (5.7 KB)
  1. %%--------------------------------------------------------------------
  2. %% Copyright (c) 2020 EMQ Technologies Co., Ltd. All Rights Reserved.
  3. %%
  4. %% Licensed under the Apache License, Version 2.0 (the "License");
  5. %% you may not use this file except in compliance with the License.
  6. %% You may obtain a copy of the License at
  7. %%
  8. %% http://www.apache.org/licenses/LICENSE-2.0
  9. %%
  10. %% Unless required by applicable law or agreed to in writing, software
  11. %% distributed under the License is distributed on an "AS IS" BASIS,
  12. %% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. %% See the License for the specific language governing permissions and
  14. %% limitations under the License.
  15. %%--------------------------------------------------------------------
  16. -module(emqx_auth_pgsql_cli).
  17. -behaviour(ecpool_worker).
  18. -include("emqx_auth_pgsql.hrl").
  19. -include_lib("emqx/include/emqx.hrl").
  20. -include_lib("emqx/include/logger.hrl").
  21. -export([connect/1]).
  22. -export([parse_query/2]).
  23. -export([ equery/4
  24. , equery/3
  25. ]).
%% Client information map consumed by equery/4 when binding query
%% placeholders. username, clientid and peerhost are required keys; any
%% other keys are allowed, and cn / dn are looked up optionally for the
%% '%C' / '%d' placeholders.
-type client_info() :: #{username:=_, clientid:=_, peerhost:=_, _=>_}.
  27. %%--------------------------------------------------------------------
  28. %% Avoid SQL Injection: Parse SQL to Parameter Query.
  29. %%--------------------------------------------------------------------
  30. parse_query(_Par, undefined) ->
  31. undefined;
  32. parse_query(Par, Sql) ->
  33. case re:run(Sql, "'%[ucCad]'", [global, {capture, all, list}]) of
  34. {match, Variables} ->
  35. Params = [Var || [Var] <- Variables],
  36. {atom_to_list(Par), Params};
  37. nomatch ->
  38. {atom_to_list(Par), []}
  39. end.
  40. pgvar(Sql, Params) ->
  41. Vars = ["$" ++ integer_to_list(I) || I <- lists:seq(1, length(Params))],
  42. lists:foldl(fun({Param, Var}, S) ->
  43. re:replace(S, Param, Var, [{return, list}])
  44. end, Sql, lists:zip(Params, Vars)).
  45. %%--------------------------------------------------------------------
  46. %% PostgreSQL Connect/Query
  47. %%--------------------------------------------------------------------
  48. %% Due to a bug in epgsql the caluse for `econnrefused` is not recognised by
  49. %% dialyzer, result in this error:
  50. %% The pattern {'error', Reason = 'econnrefused'} can never match the type ...
  51. %% https://github.com/epgsql/epgsql/issues/246
  52. -dialyzer([{nowarn_function, [connect/1]}]).
  53. connect(Opts) ->
  54. Host = proplists:get_value(host, Opts),
  55. Username = proplists:get_value(username, Opts),
  56. Password = proplists:get_value(password, Opts),
  57. case epgsql:connect(Host, Username, Password, conn_opts(Opts)) of
  58. {ok, C} ->
  59. conn_post(C),
  60. {ok, C};
  61. {error, Reason = econnrefused} ->
  62. ?LOG(error, "[Postgres] Can't connect to Postgres server: Connection refused."),
  63. {error, Reason};
  64. {error, Reason = invalid_authorization_specification} ->
  65. ?LOG(error, "[Postgres] Can't connect to Postgres server: Invalid authorization specification."),
  66. {error, Reason};
  67. {error, Reason = invalid_password} ->
  68. ?LOG(error, "[Postgres] Can't connect to Postgres server: Invalid password."),
  69. {error, Reason};
  70. {error, Reason} ->
  71. ?LOG(error, "[Postgres] Can't connect to Postgres server: ~p", [Reason]),
  72. {error, Reason}
  73. end.
  74. conn_post(Connection) ->
  75. lists:foreach(fun(Par) ->
  76. Sql0 = application:get_env(?APP, Par, undefined),
  77. case parse_query(Par, Sql0) of
  78. undefined -> ok;
  79. {_, Params} ->
  80. Sql = pgvar(Sql0, Params),
  81. epgsql:parse(Connection, atom_to_list(Par), Sql, [])
  82. end
  83. end, [auth_query, acl_query, super_query]).
  84. conn_opts(Opts) ->
  85. conn_opts(Opts, []).
  86. conn_opts([], Acc) ->
  87. Acc;
  88. conn_opts([Opt = {database, _}|Opts], Acc) ->
  89. conn_opts(Opts, [Opt|Acc]);
  90. conn_opts([Opt = {ssl, _}|Opts], Acc) ->
  91. conn_opts(Opts, [Opt|Acc]);
  92. conn_opts([Opt = {port, _}|Opts], Acc) ->
  93. conn_opts(Opts, [Opt|Acc]);
  94. conn_opts([Opt = {timeout, _}|Opts], Acc) ->
  95. conn_opts(Opts, [Opt|Acc]);
  96. conn_opts([Opt = {ssl_opts, _}|Opts], Acc) ->
  97. conn_opts(Opts, [Opt|Acc]);
  98. conn_opts([_Opt|Opts], Acc) ->
  99. conn_opts(Opts, Acc).
  100. -spec(equery(atom(), string() | epgsql:statement(), Parameters::[any()]) -> {ok, ColumnsDescription :: [any()], RowsValues :: [any()]} | {error, any()} ).
  101. equery(Pool, Sql, Params) ->
  102. ecpool:with_client(Pool, fun(C) -> epgsql:prepared_query(C, Sql, Params) end).
  103. -spec(equery(atom(), string() | epgsql:statement(), Parameters::[any()], client_info()) -> {ok, ColumnsDescription :: [any()], RowsValues :: [any()]} | {error, any()} ).
  104. equery(Pool, Sql, Params, ClientInfo) ->
  105. ecpool:with_client(Pool, fun(C) -> epgsql:prepared_query(C, Sql, replvar(Params, ClientInfo)) end).
  106. replvar(Params, ClientInfo) ->
  107. replvar(Params, ClientInfo, []).
  108. replvar([], _ClientInfo, Acc) ->
  109. lists:reverse(Acc);
  110. replvar(["'%u'" | Params], ClientInfo = #{username := Username}, Acc) ->
  111. replvar(Params, ClientInfo, [Username | Acc]);
  112. replvar(["'%c'" | Params], ClientInfo = #{clientid := ClientId}, Acc) ->
  113. replvar(Params, ClientInfo, [ClientId | Acc]);
  114. replvar(["'%a'" | Params], ClientInfo = #{peerhost := IpAddr}, Acc) ->
  115. replvar(Params, ClientInfo, [inet_parse:ntoa(IpAddr) | Acc]);
  116. replvar(["'%C'" | Params], ClientInfo, Acc) ->
  117. replvar(Params, ClientInfo, [safe_get(cn, ClientInfo)| Acc]);
  118. replvar(["'%d'" | Params], ClientInfo, Acc) ->
  119. replvar(Params, ClientInfo, [safe_get(dn, ClientInfo)| Acc]);
  120. replvar([Param | Params], ClientInfo, Acc) ->
  121. replvar(Params, ClientInfo, [Param | Acc]).
  122. safe_get(K, ClientInfo) ->
  123. bin(maps:get(K, ClientInfo, undefined)).
  124. bin(A) when is_atom(A) -> atom_to_binary(A, utf8);
  125. bin(B) when is_binary(B) -> B;
  126. bin(X) -> X.