7
7
#include < openssl/err.h>
8
8
#include < openssl/asn1.h>
9
9
10
- #include < iostream>
11
10
12
11
namespace {
13
12
14
13
std::string getCertificateInfo (X509* cert)
15
14
{
15
+ if (!cert)
16
+ return " No certificate" ;
17
+
16
18
std::unique_ptr<BIO, decltype (&BIO_free)> mem_bio (BIO_new (BIO_s_mem ()), &BIO_free);
17
19
X509_print (mem_bio.get (), cert);
20
+
18
21
char * data = nullptr ;
19
- size_t len = BIO_get_mem_data (mem_bio.get (), &data);
22
+ auto len = BIO_get_mem_data (mem_bio.get (), &data);
23
+ if (len < 0 )
24
+ return " Can't get certificate info due to BIO error " + std::to_string (len);
20
25
21
26
return std::string (data, len);
22
27
}
23
28
24
- void throwSSLError (SSL * ssl, int error, const char * location, const char * statement) {
29
+ void throwSSLError (SSL * ssl, int error, const char * /* location*/ , const char * /* statement*/ ) {
25
30
const auto detail_error = ERR_get_error ();
26
31
auto reason = ERR_reason_error_string (detail_error);
27
32
reason = reason ? reason : " Unknown SSL error" ;
28
33
29
34
std::string reason_str = reason;
30
35
if (ssl) {
31
- if (auto ssl_session = SSL_get_session (ssl))
32
- if (auto server_certificate = SSL_SESSION_get0_peer (ssl_session))
33
- reason_str += " \n Server certificate: " + getCertificateInfo (server_certificate);
36
+ // TODO: maybe print certificate only if handshake isn't completed (SSL_get_state(ssl) != TLS_ST_OK)
37
+ if (auto ssl_session = SSL_get_session (ssl)) {
38
+ reason_str += " \n Server certificate: " + getCertificateInfo (SSL_SESSION_get0_peer (ssl_session));
39
+ }
34
40
}
35
41
36
- std::cerr << " !!! SSL error at " << location
37
- << " \n\t caused by " << statement
38
- << " \n\t : " << reason_str << " (" << error << " )"
39
- << " \n\t last err: " << ERR_peek_last_error ()
40
- << std::endl;
42
+ // std::cerr << "!!! SSL error at " << location
43
+ // << "\n\tcaused by " << statement
44
+ // << "\n\t: "<< reason_str << "(" << error << ")"
45
+ // << "\n\t last err: " << ERR_peek_last_error()
46
+ // << std::endl;
41
47
42
48
throw std::runtime_error (std::string (" OpenSSL error: " ) + std::to_string (error) + " : " + reason_str);
43
49
}
@@ -64,8 +70,8 @@ SSL_CTX * prepareSSLContext(const clickhouse::SSLParams & context_params) {
64
70
throw std::runtime_error (" Failed to initialize SSL context" );
65
71
66
72
#define HANDLE_SSL_CTX_ERROR (statement ) do { \
67
- if (const auto ret_code = statement; !ret_code) \
68
- throwSSLError (nullptr , ERR_peek_error (), LOCATION, STRINGIFY ( statement) ); \
73
+ if (const auto ret_code = ( statement) ; !ret_code) \
74
+ throwSSLError (nullptr , ERR_peek_error (), LOCATION, # statement); \
69
75
} while (false );
70
76
71
77
if (context_params.use_default_ca_locations )
@@ -91,47 +97,48 @@ SSL_CTX * prepareSSLContext(const clickhouse::SSLParams & context_params) {
91
97
SSL_CTX_set_max_proto_version (ctx.get (), context_params.max_protocol_version ));
92
98
93
99
return ctx.release ();
100
+ #undef HANDLE_SSL_CTX_ERROR
94
101
}
95
102
96
-
97
-
98
103
}
99
104
100
- #define HANDLE_SSL_ERROR (statement ) do { \
101
- if (const auto ret_code = statement; ret_code <= 0 ) \
102
- throwSSLError (ssl_, SSL_get_error (ssl_, ret_code), LOCATION, STRINGIFY (statement)); \
103
- } while (false );
104
-
105
105
namespace clickhouse {
106
106
107
107
SSLContext::SSLContext (SSL_CTX & context)
108
- : context_(&context)
108
+ : context_(&context, &SSL_CTX_free )
109
109
{
110
- SSL_CTX_up_ref (context_);
110
+ SSL_CTX_up_ref (context_. get () );
111
111
}
112
112
113
113
SSLContext::SSLContext (const SSLParams & context_params)
114
- : context_(prepareSSLContext(context_params))
114
+ : context_(prepareSSLContext(context_params), &SSL_CTX_free )
115
115
{
116
116
}
117
117
118
- SSLContext::~SSLContext () {
119
- SSL_CTX_free (context_);
120
- }
121
-
122
118
SSL_CTX * SSLContext::getContext () {
123
- return context_;
119
+ return context_. get () ;
124
120
}
125
121
122
+ // Allows caller to use returned value of `statement` if there was no error, throws exception otherwise.
123
+ #define HANDLE_SSL_ERROR (statement ) [&] { \
124
+ if (const auto ret_code = (statement); ret_code <= 0 ) { \
125
+ throwSSLError (ssl_, SSL_get_error (ssl_, ret_code), LOCATION, #statement); \
126
+ return static_cast <decltype (ret_code)>(0 ); \
127
+ } \
128
+ else \
129
+ return ret_code; \
130
+ }()
131
+
126
132
/* // debug macro for tracing SSL state
127
133
#define LOG_SSL_STATE() std::cerr << "!!!!" << LOCATION << " @" << __FUNCTION__ \
128
134
<< "\t" << SSL_get_version(ssl_) << " state: " << SSL_state_string_long(ssl_) \
129
135
<< "\n\t handshake state: " << SSL_get_state(ssl_) \
130
136
<< std::endl
131
137
*/
132
138
SSLSocket::SSLSocket (const NetworkAddress& addr, const SSLParams & ssl_params, SSLContext& context)
133
- : Socket(addr),
134
- ssl_ (SSL_new(context.getContext()))
139
+ : Socket(addr)
140
+ , ssl_ptr_(SSL_new(context.getContext()), &SSL_free)
141
+ , ssl_(ssl_ptr_.get())
135
142
{
136
143
if (!ssl_)
137
144
throw std::runtime_error (" Failed to create SSL instance" );
@@ -143,41 +150,37 @@ SSLSocket::SSLSocket(const NetworkAddress& addr, const SSLParams & ssl_params, S
143
150
SSL_set_connect_state (ssl_);
144
151
HANDLE_SSL_ERROR (SSL_connect (ssl_));
145
152
HANDLE_SSL_ERROR (SSL_set_mode (ssl_, SSL_MODE_AUTO_RETRY));
153
+ auto peer_certificate = SSL_get_peer_certificate (ssl_);
146
154
147
- if (const auto verify_result = SSL_get_verify_result (ssl_); verify_result != X509_V_OK) {
148
- auto error_message = X509_verify_cert_error_string (verify_result);
149
- auto ssl_session = SSL_get_session (ssl_);
150
- auto cert = SSL_SESSION_get0_peer (ssl_session);
155
+ if (!peer_certificate)
156
+ throw std::runtime_error (" Failed to verify SSL connection: server provided no ceritificate." );
151
157
158
+ if (const auto verify_result = SSL_get_verify_result (ssl_); verify_result != X509_V_OK) {
159
+ auto error_message = X509_verify_cert_error_string (verify_result);
152
160
throw std::runtime_error (" Failed to verify SSL connection, X509_v error: "
153
- + std::to_string (verify_result)
154
- + " " + error_message + " \n " + getCertificateInfo (cert));
161
+ + std::to_string (verify_result)
162
+ + " " + error_message
163
+ + " \n Server certificate: " + getCertificateInfo (peer_certificate));
155
164
}
156
165
157
166
if (ssl_params.use_SNI ) {
158
- auto ssl_session = SSL_get_session (ssl_);
159
- auto peer_cert = SSL_SESSION_get0_peer (ssl_session);
160
167
auto hostname = addr.Host ();
161
168
char * out_name = nullptr ;
162
169
163
170
std::unique_ptr<ASN1_OCTET_STRING, decltype (&ASN1_OCTET_STRING_free)> addr (a2i_IPADDRESS (hostname.c_str ()), &ASN1_OCTET_STRING_free);
164
171
if (addr) {
165
172
// if hostname is actually an IP address
166
173
HANDLE_SSL_ERROR (X509_check_ip (
167
- peer_cert ,
174
+ peer_certificate ,
168
175
ASN1_STRING_get0_data (addr.get ()),
169
176
ASN1_STRING_length (addr.get ()),
170
177
0 ));
171
178
} else {
172
- HANDLE_SSL_ERROR (X509_check_host (peer_cert , hostname.c_str (), hostname.length (), 0 , &out_name));
179
+ HANDLE_SSL_ERROR (X509_check_host (peer_certificate , hostname.c_str (), hostname.length (), 0 , &out_name));
173
180
}
174
181
}
175
182
}
176
183
177
- SSLSocket::~SSLSocket () {
178
- SSL_free (ssl_);
179
- }
180
-
181
184
std::unique_ptr<InputStream> SSLSocket::makeInputStream () const {
182
185
return std::make_unique<SSLSocketInput>(ssl_);
183
186
}
@@ -190,8 +193,6 @@ SSLSocketInput::SSLSocketInput(SSL *ssl)
190
193
: ssl_(ssl)
191
194
{}
192
195
193
- SSLSocketInput::~SSLSocketInput () = default ;
194
-
195
196
size_t SSLSocketInput::DoRead (void * buf, size_t len) {
196
197
size_t actually_read;
197
198
HANDLE_SSL_ERROR (SSL_read_ex (ssl_, buf, len, &actually_read));
@@ -202,8 +203,6 @@ SSLSocketOutput::SSLSocketOutput(SSL *ssl)
202
203
: ssl_(ssl)
203
204
{}
204
205
205
- SSLSocketOutput::~SSLSocketOutput () = default ;
206
-
207
206
void SSLSocketOutput::DoWrite (const void * data, size_t len) {
208
207
HANDLE_SSL_ERROR (SSL_write (ssl_, data, len));
209
208
}
0 commit comments