emqx_auth_pgsql_cli.erl 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. %%--------------------------------------------------------------------
  2. %% Copyright (c) 2020 EMQ Technologies Co., Ltd. All Rights Reserved.
  3. %%
  4. %% Licensed under the Apache License, Version 2.0 (the "License");
  5. %% you may not use this file except in compliance with the License.
  6. %% You may obtain a copy of the License at
  7. %%
  8. %% http://www.apache.org/licenses/LICENSE-2.0
  9. %%
  10. %% Unless required by applicable law or agreed to in writing, software
  11. %% distributed under the License is distributed on an "AS IS" BASIS,
  12. %% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. %% See the License for the specific language governing permissions and
  14. %% limitations under the License.
  15. %%--------------------------------------------------------------------
  16. -module(emqx_auth_pgsql_cli).
  17. -behaviour(ecpool_worker).
  18. -include("emqx_auth_pgsql.hrl").
  19. -include_lib("emqx/include/emqx.hrl").
  20. -include_lib("emqx/include/logger.hrl").
  21. -export([connect/1]).
  22. -export([parse_query/2]).
  23. -export([ equery/4
  24. , equery/3
  25. ]).
  26. -type client_info() :: #{username := _,
  27. clientid := _,
  28. peerhost := _,
  29. _ => _}.
  30. %%--------------------------------------------------------------------
  31. %% Avoid SQL Injection: Parse SQL to Parameter Query.
  32. %%--------------------------------------------------------------------
  33. parse_query(_Par, undefined) ->
  34. undefined;
  35. parse_query(Par, Sql) ->
  36. case re:run(Sql, "'%[ucCad]'", [global, {capture, all, list}]) of
  37. {match, Variables} ->
  38. Params = [Var || [Var] <- Variables],
  39. {atom_to_list(Par), Params};
  40. nomatch ->
  41. {atom_to_list(Par), []}
  42. end.
  43. pgvar(Sql, Params) ->
  44. Vars = ["$" ++ integer_to_list(I) || I <- lists:seq(1, length(Params))],
  45. lists:foldl(fun({Param, Var}, S) ->
  46. re:replace(S, Param, Var, [{return, list}])
  47. end, Sql, lists:zip(Params, Vars)).
  48. %%--------------------------------------------------------------------
  49. %% PostgreSQL Connect/Query
  50. %%--------------------------------------------------------------------
  51. %% Due to a bug in epgsql the caluse for `econnrefused` is not recognised by
  52. %% dialyzer, result in this error:
  53. %% The pattern {'error', Reason = 'econnrefused'} can never match the type ...
  54. %% https://github.com/epgsql/epgsql/issues/246
  55. -dialyzer([{nowarn_function, [connect/1]}]).
  56. connect(Opts) ->
  57. Host = proplists:get_value(host, Opts),
  58. Username = proplists:get_value(username, Opts),
  59. Password = proplists:get_value(password, Opts),
  60. case epgsql:connect(Host, Username, Password, conn_opts(Opts)) of
  61. {ok, C} ->
  62. conn_post(C),
  63. {ok, C};
  64. {error, Reason = econnrefused} ->
  65. ?LOG(error, "[Postgres] Can't connect to Postgres server: Connection refused."),
  66. {error, Reason};
  67. {error, Reason = invalid_authorization_specification} ->
  68. ?LOG(error, "[Postgres] Can't connect to Postgres server: Invalid authorization specification."),
  69. {error, Reason};
  70. {error, Reason = invalid_password} ->
  71. ?LOG(error, "[Postgres] Can't connect to Postgres server: Invalid password."),
  72. {error, Reason};
  73. {error, Reason} ->
  74. ?LOG(error, "[Postgres] Can't connect to Postgres server: ~p", [Reason]),
  75. {error, Reason}
  76. end.
  77. conn_post(Connection) ->
  78. lists:foreach(fun(Par) ->
  79. Sql0 = application:get_env(?APP, Par, undefined),
  80. case parse_query(Par, Sql0) of
  81. undefined -> ok;
  82. {_, Params} ->
  83. Sql = pgvar(Sql0, Params),
  84. epgsql:parse(Connection, atom_to_list(Par), Sql, [])
  85. end
  86. end, [auth_query, acl_query, super_query]).
  87. conn_opts(Opts) ->
  88. conn_opts(Opts, []).
  89. conn_opts([], Acc) ->
  90. Acc;
  91. conn_opts([Opt = {database, _}|Opts], Acc) ->
  92. conn_opts(Opts, [Opt|Acc]);
  93. conn_opts([Opt = {ssl, _}|Opts], Acc) ->
  94. conn_opts(Opts, [Opt|Acc]);
  95. conn_opts([Opt = {port, _}|Opts], Acc) ->
  96. conn_opts(Opts, [Opt|Acc]);
  97. conn_opts([Opt = {timeout, _}|Opts], Acc) ->
  98. conn_opts(Opts, [Opt|Acc]);
  99. conn_opts([Opt = {ssl_opts, _}|Opts], Acc) ->
  100. conn_opts(Opts, [Opt|Acc]);
  101. conn_opts([_Opt|Opts], Acc) ->
  102. conn_opts(Opts, Acc).
  103. -spec(equery(atom(), string() | epgsql:statement(), Parameters::[any()]) -> {ok, ColumnsDescription :: [any()], RowsValues :: [any()]} | {error, any()} ).
  104. equery(Pool, Sql, Params) ->
  105. ecpool:with_client(Pool, fun(C) -> epgsql:prepared_query(C, Sql, Params) end).
  106. -spec(equery(atom(), string() | epgsql:statement(), Parameters::[any()], client_info()) -> {ok, ColumnsDescription :: [any()], RowsValues :: [any()]} | {error, any()} ).
  107. equery(Pool, Sql, Params, ClientInfo) ->
  108. ecpool:with_client(Pool, fun(C) -> epgsql:prepared_query(C, Sql, replvar(Params, ClientInfo)) end).
  109. replvar(Params, ClientInfo) ->
  110. replvar(Params, ClientInfo, []).
  111. replvar([], _ClientInfo, Acc) ->
  112. lists:reverse(Acc);
  113. replvar(["'%u'" | Params], ClientInfo = #{username := Username}, Acc) ->
  114. replvar(Params, ClientInfo, [Username | Acc]);
  115. replvar(["'%c'" | Params], ClientInfo = #{clientid := ClientId}, Acc) ->
  116. replvar(Params, ClientInfo, [ClientId | Acc]);
  117. replvar(["'%a'" | Params], ClientInfo = #{peerhost := IpAddr}, Acc) ->
  118. replvar(Params, ClientInfo, [inet_parse:ntoa(IpAddr) | Acc]);
  119. replvar(["'%C'" | Params], ClientInfo, Acc) ->
  120. replvar(Params, ClientInfo, [safe_get(cn, ClientInfo)| Acc]);
  121. replvar(["'%d'" | Params], ClientInfo, Acc) ->
  122. replvar(Params, ClientInfo, [safe_get(dn, ClientInfo)| Acc]);
  123. replvar([Param | Params], ClientInfo, Acc) ->
  124. replvar(Params, ClientInfo, [Param | Acc]).
  125. safe_get(K, ClientInfo) ->
  126. bin(maps:get(K, ClientInfo, undefined)).
  127. bin(A) when is_atom(A) -> atom_to_binary(A, utf8);
  128. bin(B) when is_binary(B) -> B;
  129. bin(X) -> X.