Handle write events for tcp_write_and_read mode in netevent, for both TCP and SSL.

This commit is contained in:
W.C.A. Wijngaards 2020-06-26 16:05:15 +02:00
parent 64c8d18814
commit cfe009a31c
2 changed files with 131 additions and 50 deletions

View file

@ -992,11 +992,12 @@ static void
tcp_callback_writer(struct comm_point* c)
{
log_assert(c->type == comm_tcp);
sldns_buffer_clear(c->buffer);
if(!c->tcp_write_and_read) {
sldns_buffer_clear(c->buffer);
c->tcp_byte_count = 0;
}
if(c->tcp_do_toggle_rw)
c->tcp_is_reading = 1;
if(!c->tcp_write_and_read)
c->tcp_byte_count = 0;
/* switch from listening(write) to listening(read) */
if(c->tcp_req_info) {
tcp_req_info_handle_writedone(c->tcp_req_info);
@ -1302,10 +1303,28 @@ ssl_handle_write(struct comm_point* c)
}
/* ignore return, if fails we may simply block */
(void)SSL_set_mode(c->ssl, (long)SSL_MODE_ENABLE_PARTIAL_WRITE);
if(c->tcp_byte_count < sizeof(uint16_t)) {
uint16_t len = htons(sldns_buffer_limit(c->buffer));
if((c->tcp_write_and_read?c->tcp_write_byte_count:c->tcp_byte_count) < sizeof(uint16_t)) {
uint16_t len = htons(c->tcp_write_and_read?c->tcp_write_pkt_len:sldns_buffer_limit(c->buffer));
ERR_clear_error();
if(sizeof(uint16_t)+sldns_buffer_remaining(c->buffer) <
if(c->tcp_write_and_read) {
if(c->tcp_write_pkt_len + 2 < LDNS_RR_BUF_SIZE) {
/* combine the tcp length and the query for
* write, this emulates writev */
uint8_t buf[LDNS_RR_BUF_SIZE];
memmove(buf, &len, sizeof(uint16_t));
memmove(buf+sizeof(uint16_t),
c->tcp_write_pkt,
c->tcp_write_pkt_len);
r = SSL_write(c->ssl,
(void*)(buf+c->tcp_write_byte_count),
c->tcp_write_pkt_len + 2 -
c->tcp_write_byte_count);
} else {
r = SSL_write(c->ssl,
(void*)(((uint8_t*)&len)+c->tcp_write_byte_count),
(int)(sizeof(uint16_t)-c->tcp_write_byte_count));
}
} else if(sizeof(uint16_t)+sldns_buffer_remaining(c->buffer) <
LDNS_RR_BUF_SIZE) {
/* combine the tcp length and the query for write,
* this emulates writev */
@ -1347,20 +1366,32 @@ ssl_handle_write(struct comm_point* c)
log_crypto_err("could not SSL_write");
return 0;
}
c->tcp_byte_count += r;
if(c->tcp_byte_count < sizeof(uint16_t))
return 1;
sldns_buffer_set_position(c->buffer, c->tcp_byte_count -
sizeof(uint16_t));
if(sldns_buffer_remaining(c->buffer) == 0) {
if(c->tcp_write_and_read) {
c->tcp_write_byte_count += r;
if(c->tcp_write_byte_count < sizeof(uint16_t))
return 1;
} else {
c->tcp_byte_count += r;
if(c->tcp_byte_count < sizeof(uint16_t))
return 1;
sldns_buffer_set_position(c->buffer, c->tcp_byte_count -
sizeof(uint16_t));
}
if((!c->tcp_write_and_read && sldns_buffer_remaining(c->buffer) == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == c->tcp_write_pkt_len + 2)) {
tcp_callback_writer(c);
return 1;
}
}
log_assert(sldns_buffer_remaining(c->buffer) > 0);
log_assert(c->tcp_write_and_read || sldns_buffer_remaining(c->buffer) > 0);
log_assert(!c->tcp_write_and_read || c->tcp_write_byte_count < c->tcp_write_pkt_len + 2);
ERR_clear_error();
r = SSL_write(c->ssl, (void*)sldns_buffer_current(c->buffer),
(int)sldns_buffer_remaining(c->buffer));
if(c->tcp_write_and_read) {
r = SSL_write(c->ssl, (void*)(c->tcp_write_pkt + c->tcp_write_byte_count - 2),
(int)(c->tcp_write_pkt_len + 2 - c->tcp_write_byte_count));
} else {
r = SSL_write(c->ssl, (void*)sldns_buffer_current(c->buffer),
(int)sldns_buffer_remaining(c->buffer));
}
if(r <= 0) {
int want = SSL_get_error(c->ssl, r);
if(want == SSL_ERROR_ZERO_RETURN) {
@ -1385,9 +1416,13 @@ ssl_handle_write(struct comm_point* c)
log_crypto_err("could not SSL_write");
return 0;
}
sldns_buffer_skip(c->buffer, (ssize_t)r);
if(c->tcp_write_and_read) {
c->tcp_write_byte_count += r;
} else {
sldns_buffer_skip(c->buffer, (ssize_t)r);
}
if(sldns_buffer_remaining(c->buffer) == 0) {
if((!c->tcp_write_and_read && sldns_buffer_remaining(c->buffer) == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == c->tcp_write_pkt_len + 2)) {
tcp_callback_writer(c);
}
return 1;
@ -1531,7 +1566,7 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c)
if(c->tcp_is_reading && !c->ssl)
return 0;
log_assert(fd != -1);
if(c->tcp_byte_count == 0 && c->tcp_check_nb_connect) {
if(((!c->tcp_write_and_read && c->tcp_byte_count == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == 0)) && c->tcp_check_nb_connect) {
/* check for pending error from nonblocking connect */
/* from Stevens, unix network programming, vol1, 3rd ed, p450*/
int error = 0;
@ -1581,15 +1616,22 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c)
if(c->tcp_do_fastopen == 1) {
/* this form of sendmsg() does both a connect() and send() so need to
look for various flavours of error*/
uint16_t len = htons(sldns_buffer_limit(buffer));
uint16_t len = htons(c->tcp_write_and_read?c->tcp_write_pkt_len:sldns_buffer_limit(buffer));
struct msghdr msg;
struct iovec iov[2];
c->tcp_do_fastopen = 0;
memset(&msg, 0, sizeof(msg));
iov[0].iov_base = (uint8_t*)&len + c->tcp_byte_count;
iov[0].iov_len = sizeof(uint16_t) - c->tcp_byte_count;
iov[1].iov_base = sldns_buffer_begin(buffer);
iov[1].iov_len = sldns_buffer_limit(buffer);
if(c->tcp_write_and_read) {
iov[0].iov_base = (uint8_t*)&len + c->tcp_write_byte_count;
iov[0].iov_len = sizeof(uint16_t) - c->tcp_write_byte_count;
iov[1].iov_base = c->tcp_write_pkt;
iov[1].iov_len = c->tcp_write_pkt_len;
} else {
iov[0].iov_base = (uint8_t*)&len + c->tcp_byte_count;
iov[0].iov_len = sizeof(uint16_t) - c->tcp_byte_count;
iov[1].iov_base = sldns_buffer_begin(buffer);
iov[1].iov_len = sldns_buffer_limit(buffer);
}
log_assert(iov[0].iov_len > 0);
msg.msg_name = &c->repinfo.addr;
msg.msg_namelen = c->repinfo.addrlen;
@ -1635,12 +1677,18 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c)
}
} else {
c->tcp_byte_count += r;
if(c->tcp_byte_count < sizeof(uint16_t))
return 1;
sldns_buffer_set_position(buffer, c->tcp_byte_count -
sizeof(uint16_t));
if(sldns_buffer_remaining(buffer) == 0) {
if(c->tcp_write_and_read) {
c->tcp_write_byte_count += r;
if(c->tcp_write_byte_count < sizeof(uint16_t))
return 1;
} else {
c->tcp_byte_count += r;
if(c->tcp_byte_count < sizeof(uint16_t))
return 1;
sldns_buffer_set_position(buffer, c->tcp_byte_count -
sizeof(uint16_t));
}
if((!c->tcp_write_and_read && sldns_buffer_remaining(buffer) == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == c->tcp_write_pkt_len + 2)) {
tcp_callback_writer(c);
return 1;
}
@ -1648,19 +1696,31 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c)
}
#endif /* USE_MSG_FASTOPEN */
if(c->tcp_byte_count < sizeof(uint16_t)) {
uint16_t len = htons(sldns_buffer_limit(buffer));
if((c->tcp_write_and_read?c->tcp_write_byte_count:c->tcp_byte_count) < sizeof(uint16_t)) {
uint16_t len = htons(c->tcp_write_and_read?c->tcp_write_pkt_len:sldns_buffer_limit(buffer));
#ifdef HAVE_WRITEV
struct iovec iov[2];
iov[0].iov_base = (uint8_t*)&len + c->tcp_byte_count;
iov[0].iov_len = sizeof(uint16_t) - c->tcp_byte_count;
iov[1].iov_base = sldns_buffer_begin(buffer);
iov[1].iov_len = sldns_buffer_limit(buffer);
if(c->tcp_write_and_read) {
iov[0].iov_base = (uint8_t*)&len + c->tcp_write_byte_count;
iov[0].iov_len = sizeof(uint16_t) - c->tcp_write_byte_count;
iov[1].iov_base = c->tcp_write_pkt;
iov[1].iov_len = c->tcp_write_pkt_len;
} else {
iov[0].iov_base = (uint8_t*)&len + c->tcp_byte_count;
iov[0].iov_len = sizeof(uint16_t) - c->tcp_byte_count;
iov[1].iov_base = sldns_buffer_begin(buffer);
iov[1].iov_len = sldns_buffer_limit(buffer);
}
log_assert(iov[0].iov_len > 0);
r = writev(fd, iov, 2);
#else /* HAVE_WRITEV */
r = send(fd, (void*)(((uint8_t*)&len)+c->tcp_byte_count),
sizeof(uint16_t)-c->tcp_byte_count, 0);
if(c->tcp_write_and_read) {
r = send(fd, (void*)(((uint8_t*)&len)+c->tcp_write_byte_count),
sizeof(uint16_t)-c->tcp_write_byte_count, 0);
} else {
r = send(fd, (void*)(((uint8_t*)&len)+c->tcp_byte_count),
sizeof(uint16_t)-c->tcp_byte_count, 0);
}
#endif /* HAVE_WRITEV */
if(r == -1) {
#ifndef USE_WINSOCK
@ -1699,19 +1759,31 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c)
#endif
return 0;
}
c->tcp_byte_count += r;
if(c->tcp_byte_count < sizeof(uint16_t))
return 1;
sldns_buffer_set_position(buffer, c->tcp_byte_count -
sizeof(uint16_t));
if(sldns_buffer_remaining(buffer) == 0) {
if(c->tcp_write_and_read) {
c->tcp_write_byte_count += r;
if(c->tcp_write_byte_count < sizeof(uint16_t))
return 1;
} else {
c->tcp_byte_count += r;
if(c->tcp_byte_count < sizeof(uint16_t))
return 1;
sldns_buffer_set_position(buffer, c->tcp_byte_count -
sizeof(uint16_t));
}
if((!c->tcp_write_and_read && sldns_buffer_remaining(buffer) == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == c->tcp_write_pkt_len + 2)) {
tcp_callback_writer(c);
return 1;
}
}
log_assert(sldns_buffer_remaining(buffer) > 0);
r = send(fd, (void*)sldns_buffer_current(buffer),
sldns_buffer_remaining(buffer), 0);
log_assert(c->tcp_write_and_read || sldns_buffer_remaining(buffer) > 0);
log_assert(!c->tcp_write_and_read || c->tcp_write_byte_count < c->tcp_write_pkt_len + 2);
if(c->tcp_write_and_read) {
	/* tcp_write_byte_count also counts the two tcp length bytes, so
	 * subtract them to get the offset into tcp_write_pkt. The sum is
	 * parenthesized so the arithmetic is performed on the uint8_t
	 * pointer and only the result is cast; arithmetic on a void
	 * pointer is a compiler extension and not valid ISO C. */
	r = send(fd, (void*)(c->tcp_write_pkt + c->tcp_write_byte_count - 2),
		c->tcp_write_pkt_len + 2 - c->tcp_write_byte_count, 0);
} else {
	/* regular mode: send what is left in the comm point buffer */
	r = send(fd, (void*)sldns_buffer_current(buffer),
		sldns_buffer_remaining(buffer), 0);
}
if(r == -1) {
#ifndef USE_WINSOCK
if(errno == EINTR || errno == EAGAIN)
@ -1736,9 +1808,13 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c)
#endif
return 0;
}
sldns_buffer_skip(buffer, r);
if(c->tcp_write_and_read) {
c->tcp_write_byte_count += r;
} else {
sldns_buffer_skip(buffer, r);
}
if(sldns_buffer_remaining(buffer) == 0) {
if((!c->tcp_write_and_read && sldns_buffer_remaining(buffer) == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == c->tcp_write_pkt_len + 2)) {
tcp_callback_writer(c);
}

View file

@ -254,11 +254,16 @@ struct comm_point {
int tcp_write_and_read;
/** byte count for written length over write channel, for when
* tcp_write_and_read is enabled */
* tcp_write_and_read is enabled. When tcp_write_and_read is enabled,
* this is the counter for writing, the one for reading is in the
* commpoint.buffer sldns buffer. The counter counts from 0 to
* 2+tcp_write_pkt_len, and includes the tcp length bytes. */
size_t tcp_write_byte_count;
/** packet to write currently over the write channel. for when
* tcp_write_and_read is enabled */
* tcp_write_and_read is enabled. When tcp_write_and_read is enabled,
* this is the buffer for the written packet, the commpoint.buffer
* sldns buffer is the buffer for the received packet. */
uint8_t* tcp_write_pkt;
/** length of tcp_write_pkt in bytes */
size_t tcp_write_pkt_len;