Skip to content

Commit 9cb9ce4

Browse files
committed
TLS: Fixed issues found during PR review
1 parent ceab158 commit 9cb9ce4

19 files changed

+475
-297
lines changed

.github/workflows/linux_ssl.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@ jobs:
1212
uses: Enmk/clickhouse-cpp/.github/workflows/linux.yml@master
1313
with:
1414
extra_cmake_flags: -DWITH_OPENSSL=ON
15-
extra_install: libssl1.1 libssl-dev
16-
gtest_args: --gtest_filter="-*LocalhostTLS*"
15+
extra_install: libssl-dev
16+
# gtest_args: --gtest_filter="-*LocalhostTLS*"

CMakeLists.txt

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,30 @@ INCLUDE (cmake/cpp17.cmake)
44
INCLUDE (cmake/subdirs.cmake)
55
INCLUDE (cmake/openssl.cmake)
66

7-
OPTION(BUILD_BENCHMARK "Build benchmark" OFF)
8-
OPTION(BUILD_TESTS "Build tests" OFF)
9-
OPTION(WITH_OPENSSL "Use OpenSSL for TLS connections" OFF)
7+
OPTION (BUILD_BENCHMARK "Build benchmark" OFF)
8+
OPTION (BUILD_TESTS "Build tests" OFF)
9+
OPTION (WITH_OPENSSL "Use OpenSSL for TLS connections" OFF)
1010

1111
PROJECT (CLICKHOUSE-CLIENT)
1212

1313
USE_CXX17()
1414
USE_OPENSSL()
1515

1616
IF ("${CMAKE_BUILD_TYPE}" STREQUAL "")
17-
set(CMAKE_BUILD_TYPE "Debug")
17+
SET (CMAKE_BUILD_TYPE "Debug")
1818
ENDIF()
1919

2020
IF (UNIX)
2121
IF (APPLE)
2222
SET (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -Wall -Wextra -Werror")
2323
ELSE ()
24-
SET (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -pthread -Wall -Wextra -Werror")
24+
SET (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -Wall -Wextra -Werror")
2525
ENDIF ()
2626
SET (CMAKE_EXE_LINKER_FLAGS, "${CMAKE_EXE_LINKER_FLAGS} -lpthread")
2727
ENDIF ()
2828

29-
INCLUDE_DIRECTORIES(.)
30-
INCLUDE_DIRECTORIES(contrib)
29+
INCLUDE_DIRECTORIES (.)
30+
INCLUDE_DIRECTORIES (contrib)
3131

3232
SUBDIRS (
3333
clickhouse

clickhouse/CMakeLists.txt

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
SET ( clickhouse-cpp-lib-src
2+
base/coded.cpp
23
base/compressed.cpp
34
base/input.cpp
45
base/output.cpp
56
base/platform.cpp
67
base/socket.cpp
7-
base/coded.cpp
88

99
columns/array.cpp
1010
columns/date.cpp
@@ -31,9 +31,9 @@ SET ( clickhouse-cpp-lib-src
3131
query.cpp
3232
)
3333

34-
if (WITH_OPENSSL)
34+
IF (WITH_OPENSSL)
3535
LIST(APPEND clickhouse-cpp-lib-src base/sslsocket.cpp)
36-
endif()
36+
ENDIF ()
3737

3838
ADD_LIBRARY (clickhouse-cpp-lib SHARED ${clickhouse-cpp-lib-src})
3939
SET_TARGET_PROPERTIES(clickhouse-cpp-lib PROPERTIES LINKER_LANGUAGE CXX)
@@ -56,7 +56,7 @@ IF (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
5656
TARGET_LINK_LIBRARIES (clickhouse-cpp-lib-static gcc_s)
5757
ENDIF ()
5858

59-
INSTALL(TARGETS clickhouse-cpp-lib clickhouse-cpp-lib-static
59+
INSTALL (TARGETS clickhouse-cpp-lib clickhouse-cpp-lib-static
6060
ARCHIVE DESTINATION lib
6161
LIBRARY DESTINATION lib
6262
)
@@ -71,6 +71,7 @@ INSTALL(FILES protocol.h DESTINATION include/clickhouse/)
7171
INSTALL(FILES query.h DESTINATION include/clickhouse/)
7272

7373
# base
74+
INSTALL(FILES base/coded.h DESTINATION include/clickhouse/base/)
7475
INSTALL(FILES base/buffer.h DESTINATION include/clickhouse/base/)
7576
INSTALL(FILES base/compressed.h DESTINATION include/clickhouse/base/)
7677
INSTALL(FILES base/input.h DESTINATION include/clickhouse/base/)
@@ -104,7 +105,7 @@ INSTALL(FILES columns/uuid.h DESTINATION include/clickhouse/columns/)
104105
INSTALL(FILES types/type_parser.h DESTINATION include/clickhouse/types/)
105106
INSTALL(FILES types/types.h DESTINATION include/clickhouse/types/)
106107

107-
if (WITH_OPENSSL)
108-
target_link_libraries(clickhouse-cpp-lib OpenSSL::SSL)
109-
target_link_libraries(clickhouse-cpp-lib-static OpenSSL::SSL)
110-
endif()
108+
IF (WITH_OPENSSL)
109+
TARGET_LINK_LIBRARIES (clickhouse-cpp-lib OpenSSL::SSL)
110+
TARGET_LINK_LIBRARIES (clickhouse-cpp-lib-static OpenSSL::SSL)
111+
ENDIF ()

clickhouse/base/socket.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ SOCKET SocketConnect(const NetworkAddress& addr) {
131131
} // namespace
132132

133133
NetworkAddress::NetworkAddress(const std::string& host, const std::string& port)
134-
: host_(host),
135-
info_(nullptr)
134+
: host_(host)
135+
, info_(nullptr)
136136
{
137137
struct addrinfo hints;
138138
memset(&hints, 0, sizeof(hints));

clickhouse/base/socket.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ struct addrinfo;
3030

3131
namespace clickhouse {
3232

33-
/**
33+
/** Address of a host to establish connection to.
3434
*
3535
*/
3636
class NetworkAddress {

clickhouse/base/sslsocket.cpp

Lines changed: 47 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -7,37 +7,43 @@
77
#include <openssl/err.h>
88
#include <openssl/asn1.h>
99

10-
#include <iostream>
1110

1211
namespace {
1312

1413
std::string getCertificateInfo(X509* cert)
1514
{
15+
if (!cert)
16+
return "No certificate";
17+
1618
std::unique_ptr<BIO, decltype(&BIO_free)> mem_bio(BIO_new(BIO_s_mem()), &BIO_free);
1719
X509_print(mem_bio.get(), cert);
20+
1821
char * data = nullptr;
19-
size_t len = BIO_get_mem_data(mem_bio.get(), &data);
22+
auto len = BIO_get_mem_data(mem_bio.get(), &data);
23+
if (len < 0)
24+
return "Can't get certificate info due to BIO error " + std::to_string(len);
2025

2126
return std::string(data, len);
2227
}
2328

24-
void throwSSLError(SSL * ssl, int error, const char * location, const char * statement) {
29+
void throwSSLError(SSL * ssl, int error, const char * /*location*/, const char * /*statement*/) {
2530
const auto detail_error = ERR_get_error();
2631
auto reason = ERR_reason_error_string(detail_error);
2732
reason = reason ? reason : "Unknown SSL error";
2833

2934
std::string reason_str = reason;
3035
if (ssl) {
31-
if (auto ssl_session = SSL_get_session(ssl))
32-
if (auto server_certificate = SSL_SESSION_get0_peer(ssl_session))
33-
reason_str += "\nServer certificate: " + getCertificateInfo(server_certificate);
36+
// TODO: maybe print certificate only if handshake isn't completed (SSL_get_state(ssl) != TLS_ST_OK)
37+
if (auto ssl_session = SSL_get_session(ssl)) {
38+
reason_str += "\nServer certificate: " + getCertificateInfo(SSL_SESSION_get0_peer(ssl_session));
39+
}
3440
}
3541

36-
std::cerr << "!!! SSL error at " << location
37-
<< "\n\tcaused by " << statement
38-
<< "\n\t: "<< reason_str << "(" << error << ")"
39-
<< "\n\t last err: " << ERR_peek_last_error()
40-
<< std::endl;
42+
// std::cerr << "!!! SSL error at " << location
43+
// << "\n\tcaused by " << statement
44+
// << "\n\t: "<< reason_str << "(" << error << ")"
45+
// << "\n\t last err: " << ERR_peek_last_error()
46+
// << std::endl;
4147

4248
throw std::runtime_error(std::string("OpenSSL error: ") + std::to_string(error) + " : " + reason_str);
4349
}
@@ -64,8 +70,8 @@ SSL_CTX * prepareSSLContext(const clickhouse::SSLParams & context_params) {
6470
throw std::runtime_error("Failed to initialize SSL context");
6571

6672
#define HANDLE_SSL_CTX_ERROR(statement) do { \
67-
if (const auto ret_code = statement; !ret_code) \
68-
throwSSLError(nullptr, ERR_peek_error(), LOCATION, STRINGIFY(statement)); \
73+
if (const auto ret_code = (statement); !ret_code) \
74+
throwSSLError(nullptr, ERR_peek_error(), LOCATION, #statement); \
6975
} while(false);
7076

7177
if (context_params.use_default_ca_locations)
@@ -91,47 +97,48 @@ SSL_CTX * prepareSSLContext(const clickhouse::SSLParams & context_params) {
9197
SSL_CTX_set_max_proto_version(ctx.get(), context_params.max_protocol_version));
9298

9399
return ctx.release();
100+
#undef HANDLE_SSL_CTX_ERROR
94101
}
95102

96-
97-
98103
}
99104

100-
#define HANDLE_SSL_ERROR(statement) do { \
101-
if (const auto ret_code = statement; ret_code <= 0) \
102-
throwSSLError(ssl_, SSL_get_error(ssl_, ret_code), LOCATION, STRINGIFY(statement)); \
103-
} while(false);
104-
105105
namespace clickhouse {
106106

107107
SSLContext::SSLContext(SSL_CTX & context)
108-
: context_(&context)
108+
: context_(&context, &SSL_CTX_free)
109109
{
110-
SSL_CTX_up_ref(context_);
110+
SSL_CTX_up_ref(context_.get());
111111
}
112112

113113
SSLContext::SSLContext(const SSLParams & context_params)
114-
: context_(prepareSSLContext(context_params))
114+
: context_(prepareSSLContext(context_params), &SSL_CTX_free)
115115
{
116116
}
117117

118-
SSLContext::~SSLContext() {
119-
SSL_CTX_free(context_);
120-
}
121-
122118
SSL_CTX * SSLContext::getContext() {
123-
return context_;
119+
return context_.get();
124120
}
125121

122+
// Allows caller to use returned value of `statement` if there was no error, throws exception otherwise.
123+
#define HANDLE_SSL_ERROR(statement) [&] { \
124+
if (const auto ret_code = (statement); ret_code <= 0) { \
125+
throwSSLError(ssl_, SSL_get_error(ssl_, ret_code), LOCATION, #statement); \
126+
return static_cast<decltype(ret_code)>(0); \
127+
} \
128+
else \
129+
return ret_code; \
130+
}()
131+
126132
/* // debug macro for tracing SSL state
127133
#define LOG_SSL_STATE() std::cerr << "!!!!" << LOCATION << " @" << __FUNCTION__ \
128134
<< "\t" << SSL_get_version(ssl_) << " state: " << SSL_state_string_long(ssl_) \
129135
<< "\n\t handshake state: " << SSL_get_state(ssl_) \
130136
<< std::endl
131137
*/
132138
SSLSocket::SSLSocket(const NetworkAddress& addr, const SSLParams & ssl_params, SSLContext& context)
133-
: Socket(addr),
134-
ssl_(SSL_new(context.getContext()))
139+
: Socket(addr)
140+
, ssl_ptr_(SSL_new(context.getContext()), &SSL_free)
141+
, ssl_(ssl_ptr_.get())
135142
{
136143
if (!ssl_)
137144
throw std::runtime_error("Failed to create SSL instance");
@@ -143,41 +150,37 @@ SSLSocket::SSLSocket(const NetworkAddress& addr, const SSLParams & ssl_params, S
143150
SSL_set_connect_state(ssl_);
144151
HANDLE_SSL_ERROR(SSL_connect(ssl_));
145152
HANDLE_SSL_ERROR(SSL_set_mode(ssl_, SSL_MODE_AUTO_RETRY));
153+
auto peer_certificate = SSL_get_peer_certificate(ssl_);
146154

147-
if(const auto verify_result = SSL_get_verify_result(ssl_); verify_result != X509_V_OK) {
148-
auto error_message = X509_verify_cert_error_string(verify_result);
149-
auto ssl_session = SSL_get_session(ssl_);
150-
auto cert = SSL_SESSION_get0_peer(ssl_session);
155+
if (!peer_certificate)
156+
throw std::runtime_error("Failed to verify SSL connection: server provided no ceritificate.");
151157

158+
if (const auto verify_result = SSL_get_verify_result(ssl_); verify_result != X509_V_OK) {
159+
auto error_message = X509_verify_cert_error_string(verify_result);
152160
throw std::runtime_error("Failed to verify SSL connection, X509_v error: "
153-
+ std::to_string(verify_result)
154-
+ " " + error_message + "\n" + getCertificateInfo(cert));
161+
+ std::to_string(verify_result)
162+
+ " " + error_message
163+
+ "\nServer certificate: " + getCertificateInfo(peer_certificate));
155164
}
156165

157166
if (ssl_params.use_SNI) {
158-
auto ssl_session = SSL_get_session(ssl_);
159-
auto peer_cert = SSL_SESSION_get0_peer(ssl_session);
160167
auto hostname = addr.Host();
161168
char * out_name = nullptr;
162169

163170
std::unique_ptr<ASN1_OCTET_STRING, decltype(&ASN1_OCTET_STRING_free)> addr(a2i_IPADDRESS(hostname.c_str()), &ASN1_OCTET_STRING_free);
164171
if (addr) {
165172
// if hostname is actually an IP address
166173
HANDLE_SSL_ERROR(X509_check_ip(
167-
peer_cert,
174+
peer_certificate,
168175
ASN1_STRING_get0_data(addr.get()),
169176
ASN1_STRING_length(addr.get()),
170177
0));
171178
} else {
172-
HANDLE_SSL_ERROR(X509_check_host(peer_cert, hostname.c_str(), hostname.length(), 0, &out_name));
179+
HANDLE_SSL_ERROR(X509_check_host(peer_certificate, hostname.c_str(), hostname.length(), 0, &out_name));
173180
}
174181
}
175182
}
176183

177-
SSLSocket::~SSLSocket() {
178-
SSL_free(ssl_);
179-
}
180-
181184
std::unique_ptr<InputStream> SSLSocket::makeInputStream() const {
182185
return std::make_unique<SSLSocketInput>(ssl_);
183186
}
@@ -190,8 +193,6 @@ SSLSocketInput::SSLSocketInput(SSL *ssl)
190193
: ssl_(ssl)
191194
{}
192195

193-
SSLSocketInput::~SSLSocketInput() = default;
194-
195196
size_t SSLSocketInput::DoRead(void* buf, size_t len) {
196197
size_t actually_read;
197198
HANDLE_SSL_ERROR(SSL_read_ex(ssl_, buf, len, &actually_read));
@@ -202,8 +203,6 @@ SSLSocketOutput::SSLSocketOutput(SSL *ssl)
202203
: ssl_(ssl)
203204
{}
204205

205-
SSLSocketOutput::~SSLSocketOutput() = default;
206-
207206
void SSLSocketOutput::DoWrite(const void* data, size_t len) {
208207
HANDLE_SSL_ERROR(SSL_write(ssl_, data, len));
209208
}

clickhouse/base/sslsocket.h

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#include "socket.h"
44

5+
#include <memory>
6+
57
typedef struct ssl_ctx_st SSL_CTX;
68
typedef struct ssl_st SSL;
79

@@ -23,7 +25,7 @@ class SSLContext
2325
public:
2426
explicit SSLContext(SSL_CTX & context);
2527
explicit SSLContext(const SSLParams & context_params);
26-
~SSLContext();
28+
~SSLContext() = default;
2729

2830
SSLContext(const SSLContext &) = delete;
2931
SSLContext& operator=(const SSLContext &) = delete;
@@ -35,14 +37,14 @@ class SSLContext
3537
SSL_CTX * getContext();
3638

3739
private:
38-
SSL_CTX * const context_;
40+
std::unique_ptr<SSL_CTX, void (*)(SSL_CTX*)> context_;
3941
};
4042

4143
class SSLSocket : public Socket {
4244
public:
4345
explicit SSLSocket(const NetworkAddress& addr, const SSLParams & ssl_params, SSLContext& context);
4446
SSLSocket(SSLSocket &&) = default;
45-
~SSLSocket();
47+
~SSLSocket() = default;
4648

4749
SSLSocket(const SSLSocket & ) = delete;
4850
SSLSocket& operator=(const SSLSocket & ) = delete;
@@ -51,30 +53,33 @@ class SSLSocket : public Socket {
5153
std::unique_ptr<OutputStream> makeOutputStream() const override;
5254

5355
private:
54-
SSL *ssl_;
56+
std::unique_ptr<SSL, void (*)(SSL *s)> ssl_ptr_;
57+
SSL *ssl_; // for convinience with SSL API
5558
};
5659

5760
class SSLSocketInput : public InputStream {
5861
public:
5962
explicit SSLSocketInput(SSL *ssl);
60-
~SSLSocketInput();
63+
~SSLSocketInput() = default;
6164

6265
protected:
6366
size_t DoRead(void* buf, size_t len) override;
6467

6568
private:
69+
// Not owning
6670
SSL *ssl_;
6771
};
6872

6973
class SSLSocketOutput : public OutputStream {
7074
public:
7175
explicit SSLSocketOutput(SSL *ssl);
72-
~SSLSocketOutput();
76+
~SSLSocketOutput() = default;
7377

7478
protected:
7579
void DoWrite(const void* data, size_t len) override;
7680

7781
private:
82+
// Not owning
7883
SSL *ssl_;
7984
};
8085

0 commit comments

Comments
 (0)