mirror of
https://github.com/NLnetLabs/unbound.git
synced 2025-12-26 09:39:43 -05:00
tcp read and write handling of write events in netevent for tcp and ssl.
This commit is contained in:
parent
64c8d18814
commit
cfe009a31c
2 changed files with 131 additions and 50 deletions
172
util/netevent.c
172
util/netevent.c
|
|
@ -992,11 +992,12 @@ static void
|
|||
tcp_callback_writer(struct comm_point* c)
|
||||
{
|
||||
log_assert(c->type == comm_tcp);
|
||||
sldns_buffer_clear(c->buffer);
|
||||
if(!c->tcp_write_and_read) {
|
||||
sldns_buffer_clear(c->buffer);
|
||||
c->tcp_byte_count = 0;
|
||||
}
|
||||
if(c->tcp_do_toggle_rw)
|
||||
c->tcp_is_reading = 1;
|
||||
if(!c->tcp_write_and_read)
|
||||
c->tcp_byte_count = 0;
|
||||
/* switch from listening(write) to listening(read) */
|
||||
if(c->tcp_req_info) {
|
||||
tcp_req_info_handle_writedone(c->tcp_req_info);
|
||||
|
|
@ -1302,10 +1303,28 @@ ssl_handle_write(struct comm_point* c)
|
|||
}
|
||||
/* ignore return, if fails we may simply block */
|
||||
(void)SSL_set_mode(c->ssl, (long)SSL_MODE_ENABLE_PARTIAL_WRITE);
|
||||
if(c->tcp_byte_count < sizeof(uint16_t)) {
|
||||
uint16_t len = htons(sldns_buffer_limit(c->buffer));
|
||||
if((c->tcp_write_and_read?c->tcp_write_byte_count:c->tcp_byte_count) < sizeof(uint16_t)) {
|
||||
uint16_t len = htons(c->tcp_write_and_read?c->tcp_write_pkt_len:sldns_buffer_limit(c->buffer));
|
||||
ERR_clear_error();
|
||||
if(sizeof(uint16_t)+sldns_buffer_remaining(c->buffer) <
|
||||
if(c->tcp_write_and_read) {
|
||||
if(c->tcp_write_pkt_len + 2 < LDNS_RR_BUF_SIZE) {
|
||||
/* combine the tcp length and the query for
|
||||
* write, this emulates writev */
|
||||
uint8_t buf[LDNS_RR_BUF_SIZE];
|
||||
memmove(buf, &len, sizeof(uint16_t));
|
||||
memmove(buf+sizeof(uint16_t),
|
||||
c->tcp_write_pkt,
|
||||
c->tcp_write_pkt_len);
|
||||
r = SSL_write(c->ssl,
|
||||
(void*)(buf+c->tcp_write_byte_count),
|
||||
c->tcp_write_pkt_len + 2 -
|
||||
c->tcp_write_byte_count);
|
||||
} else {
|
||||
r = SSL_write(c->ssl,
|
||||
(void*)(((uint8_t*)&len)+c->tcp_write_byte_count),
|
||||
(int)(sizeof(uint16_t)-c->tcp_write_byte_count));
|
||||
}
|
||||
} else if(sizeof(uint16_t)+sldns_buffer_remaining(c->buffer) <
|
||||
LDNS_RR_BUF_SIZE) {
|
||||
/* combine the tcp length and the query for write,
|
||||
* this emulates writev */
|
||||
|
|
@ -1347,20 +1366,32 @@ ssl_handle_write(struct comm_point* c)
|
|||
log_crypto_err("could not SSL_write");
|
||||
return 0;
|
||||
}
|
||||
c->tcp_byte_count += r;
|
||||
if(c->tcp_byte_count < sizeof(uint16_t))
|
||||
return 1;
|
||||
sldns_buffer_set_position(c->buffer, c->tcp_byte_count -
|
||||
sizeof(uint16_t));
|
||||
if(sldns_buffer_remaining(c->buffer) == 0) {
|
||||
if(c->tcp_write_and_read) {
|
||||
c->tcp_write_byte_count += r;
|
||||
if(c->tcp_write_byte_count < sizeof(uint16_t))
|
||||
return 1;
|
||||
} else {
|
||||
c->tcp_byte_count += r;
|
||||
if(c->tcp_byte_count < sizeof(uint16_t))
|
||||
return 1;
|
||||
sldns_buffer_set_position(c->buffer, c->tcp_byte_count -
|
||||
sizeof(uint16_t));
|
||||
}
|
||||
if((!c->tcp_write_and_read && sldns_buffer_remaining(c->buffer) == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == c->tcp_write_pkt_len + 2)) {
|
||||
tcp_callback_writer(c);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
log_assert(sldns_buffer_remaining(c->buffer) > 0);
|
||||
log_assert(c->tcp_write_and_read || sldns_buffer_remaining(c->buffer) > 0);
|
||||
log_assert(!c->tcp_write_and_read || c->tcp_write_byte_count < c->tcp_write_pkt_len + 2);
|
||||
ERR_clear_error();
|
||||
r = SSL_write(c->ssl, (void*)sldns_buffer_current(c->buffer),
|
||||
(int)sldns_buffer_remaining(c->buffer));
|
||||
if(c->tcp_write_and_read) {
|
||||
r = SSL_write(c->ssl, (void*)(c->tcp_write_pkt + c->tcp_write_byte_count - 2),
|
||||
(int)(c->tcp_write_pkt_len + 2 - c->tcp_write_byte_count));
|
||||
} else {
|
||||
r = SSL_write(c->ssl, (void*)sldns_buffer_current(c->buffer),
|
||||
(int)sldns_buffer_remaining(c->buffer));
|
||||
}
|
||||
if(r <= 0) {
|
||||
int want = SSL_get_error(c->ssl, r);
|
||||
if(want == SSL_ERROR_ZERO_RETURN) {
|
||||
|
|
@ -1385,9 +1416,13 @@ ssl_handle_write(struct comm_point* c)
|
|||
log_crypto_err("could not SSL_write");
|
||||
return 0;
|
||||
}
|
||||
sldns_buffer_skip(c->buffer, (ssize_t)r);
|
||||
if(c->tcp_write_and_read) {
|
||||
c->tcp_write_byte_count += r;
|
||||
} else {
|
||||
sldns_buffer_skip(c->buffer, (ssize_t)r);
|
||||
}
|
||||
|
||||
if(sldns_buffer_remaining(c->buffer) == 0) {
|
||||
if((!c->tcp_write_and_read && sldns_buffer_remaining(c->buffer) == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == c->tcp_write_pkt_len + 2)) {
|
||||
tcp_callback_writer(c);
|
||||
}
|
||||
return 1;
|
||||
|
|
@ -1531,7 +1566,7 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c)
|
|||
if(c->tcp_is_reading && !c->ssl)
|
||||
return 0;
|
||||
log_assert(fd != -1);
|
||||
if(c->tcp_byte_count == 0 && c->tcp_check_nb_connect) {
|
||||
if(((!c->tcp_write_and_read && c->tcp_byte_count == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == 0)) && c->tcp_check_nb_connect) {
|
||||
/* check for pending error from nonblocking connect */
|
||||
/* from Stevens, unix network programming, vol1, 3rd ed, p450*/
|
||||
int error = 0;
|
||||
|
|
@ -1581,15 +1616,22 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c)
|
|||
if(c->tcp_do_fastopen == 1) {
|
||||
/* this form of sendmsg() does both a connect() and send() so need to
|
||||
look for various flavours of error*/
|
||||
uint16_t len = htons(sldns_buffer_limit(buffer));
|
||||
uint16_t len = htons(c->tcp_write_and_read?c->tcp_write_pkt_len:sldns_buffer_limit(buffer));
|
||||
struct msghdr msg;
|
||||
struct iovec iov[2];
|
||||
c->tcp_do_fastopen = 0;
|
||||
memset(&msg, 0, sizeof(msg));
|
||||
iov[0].iov_base = (uint8_t*)&len + c->tcp_byte_count;
|
||||
iov[0].iov_len = sizeof(uint16_t) - c->tcp_byte_count;
|
||||
iov[1].iov_base = sldns_buffer_begin(buffer);
|
||||
iov[1].iov_len = sldns_buffer_limit(buffer);
|
||||
if(c->tcp_write_and_read) {
|
||||
iov[0].iov_base = (uint8_t*)&len + c->tcp_write_byte_count;
|
||||
iov[0].iov_len = sizeof(uint16_t) - c->tcp_write_byte_count;
|
||||
iov[1].iov_base = c->tcp_write_pkt;
|
||||
iov[1].iov_len = c->tcp_write_pkt_len;
|
||||
} else {
|
||||
iov[0].iov_base = (uint8_t*)&len + c->tcp_byte_count;
|
||||
iov[0].iov_len = sizeof(uint16_t) - c->tcp_byte_count;
|
||||
iov[1].iov_base = sldns_buffer_begin(buffer);
|
||||
iov[1].iov_len = sldns_buffer_limit(buffer);
|
||||
}
|
||||
log_assert(iov[0].iov_len > 0);
|
||||
msg.msg_name = &c->repinfo.addr;
|
||||
msg.msg_namelen = c->repinfo.addrlen;
|
||||
|
|
@ -1635,12 +1677,18 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c)
|
|||
}
|
||||
|
||||
} else {
|
||||
c->tcp_byte_count += r;
|
||||
if(c->tcp_byte_count < sizeof(uint16_t))
|
||||
return 1;
|
||||
sldns_buffer_set_position(buffer, c->tcp_byte_count -
|
||||
sizeof(uint16_t));
|
||||
if(sldns_buffer_remaining(buffer) == 0) {
|
||||
if(c->tcp_write_and_read) {
|
||||
c->tcp_write_byte_count += r;
|
||||
if(c->tcp_write_byte_count < sizeof(uint16_t))
|
||||
return 1;
|
||||
} else {
|
||||
c->tcp_byte_count += r;
|
||||
if(c->tcp_byte_count < sizeof(uint16_t))
|
||||
return 1;
|
||||
sldns_buffer_set_position(buffer, c->tcp_byte_count -
|
||||
sizeof(uint16_t));
|
||||
}
|
||||
if((!c->tcp_write_and_read && sldns_buffer_remaining(buffer) == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == c->tcp_write_pkt_len + 2)) {
|
||||
tcp_callback_writer(c);
|
||||
return 1;
|
||||
}
|
||||
|
|
@ -1648,19 +1696,31 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c)
|
|||
}
|
||||
#endif /* USE_MSG_FASTOPEN */
|
||||
|
||||
if(c->tcp_byte_count < sizeof(uint16_t)) {
|
||||
uint16_t len = htons(sldns_buffer_limit(buffer));
|
||||
if((c->tcp_write_and_read?c->tcp_write_byte_count:c->tcp_byte_count) < sizeof(uint16_t)) {
|
||||
uint16_t len = htons(c->tcp_write_and_read?c->tcp_write_pkt_len:sldns_buffer_limit(buffer));
|
||||
#ifdef HAVE_WRITEV
|
||||
struct iovec iov[2];
|
||||
iov[0].iov_base = (uint8_t*)&len + c->tcp_byte_count;
|
||||
iov[0].iov_len = sizeof(uint16_t) - c->tcp_byte_count;
|
||||
iov[1].iov_base = sldns_buffer_begin(buffer);
|
||||
iov[1].iov_len = sldns_buffer_limit(buffer);
|
||||
if(c->tcp_write_and_read) {
|
||||
iov[0].iov_base = (uint8_t*)&len + c->tcp_write_byte_count;
|
||||
iov[0].iov_len = sizeof(uint16_t) - c->tcp_write_byte_count;
|
||||
iov[1].iov_base = c->tcp_write_pkt;
|
||||
iov[1].iov_len = c->tcp_write_pkt_len;
|
||||
} else {
|
||||
iov[0].iov_base = (uint8_t*)&len + c->tcp_byte_count;
|
||||
iov[0].iov_len = sizeof(uint16_t) - c->tcp_byte_count;
|
||||
iov[1].iov_base = sldns_buffer_begin(buffer);
|
||||
iov[1].iov_len = sldns_buffer_limit(buffer);
|
||||
}
|
||||
log_assert(iov[0].iov_len > 0);
|
||||
r = writev(fd, iov, 2);
|
||||
#else /* HAVE_WRITEV */
|
||||
r = send(fd, (void*)(((uint8_t*)&len)+c->tcp_byte_count),
|
||||
sizeof(uint16_t)-c->tcp_byte_count, 0);
|
||||
if(c->tcp_write_and_read) {
|
||||
r = send(fd, (void*)(((uint8_t*)&len)+c->tcp_write_byte_count),
|
||||
sizeof(uint16_t)-c->tcp_write_byte_count, 0);
|
||||
} else {
|
||||
r = send(fd, (void*)(((uint8_t*)&len)+c->tcp_byte_count),
|
||||
sizeof(uint16_t)-c->tcp_byte_count, 0);
|
||||
}
|
||||
#endif /* HAVE_WRITEV */
|
||||
if(r == -1) {
|
||||
#ifndef USE_WINSOCK
|
||||
|
|
@ -1699,19 +1759,31 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c)
|
|||
#endif
|
||||
return 0;
|
||||
}
|
||||
c->tcp_byte_count += r;
|
||||
if(c->tcp_byte_count < sizeof(uint16_t))
|
||||
return 1;
|
||||
sldns_buffer_set_position(buffer, c->tcp_byte_count -
|
||||
sizeof(uint16_t));
|
||||
if(sldns_buffer_remaining(buffer) == 0) {
|
||||
if(c->tcp_write_and_read) {
|
||||
c->tcp_write_byte_count += r;
|
||||
if(c->tcp_write_byte_count < sizeof(uint16_t))
|
||||
return 1;
|
||||
} else {
|
||||
c->tcp_byte_count += r;
|
||||
if(c->tcp_byte_count < sizeof(uint16_t))
|
||||
return 1;
|
||||
sldns_buffer_set_position(buffer, c->tcp_byte_count -
|
||||
sizeof(uint16_t));
|
||||
}
|
||||
if((!c->tcp_write_and_read && sldns_buffer_remaining(buffer) == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == c->tcp_write_pkt_len + 2)) {
|
||||
tcp_callback_writer(c);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
log_assert(sldns_buffer_remaining(buffer) > 0);
|
||||
r = send(fd, (void*)sldns_buffer_current(buffer),
|
||||
sldns_buffer_remaining(buffer), 0);
|
||||
log_assert(c->tcp_write_and_read || sldns_buffer_remaining(buffer) > 0);
|
||||
log_assert(!c->tcp_write_and_read || c->tcp_write_byte_count < c->tcp_write_pkt_len + 2);
|
||||
if(c->tcp_write_and_read) {
|
||||
r = send(fd, (void*)c->tcp_write_pkt + c->tcp_write_byte_count - 2,
|
||||
c->tcp_write_pkt_len + 2 - c->tcp_write_byte_count, 0);
|
||||
} else {
|
||||
r = send(fd, (void*)sldns_buffer_current(buffer),
|
||||
sldns_buffer_remaining(buffer), 0);
|
||||
}
|
||||
if(r == -1) {
|
||||
#ifndef USE_WINSOCK
|
||||
if(errno == EINTR || errno == EAGAIN)
|
||||
|
|
@ -1736,9 +1808,13 @@ comm_point_tcp_handle_write(int fd, struct comm_point* c)
|
|||
#endif
|
||||
return 0;
|
||||
}
|
||||
sldns_buffer_skip(buffer, r);
|
||||
if(c->tcp_write_and_read) {
|
||||
c->tcp_write_byte_count += r;
|
||||
} else {
|
||||
sldns_buffer_skip(buffer, r);
|
||||
}
|
||||
|
||||
if(sldns_buffer_remaining(buffer) == 0) {
|
||||
if((!c->tcp_write_and_read && sldns_buffer_remaining(buffer) == 0) || (c->tcp_write_and_read && c->tcp_write_byte_count == c->tcp_write_pkt_len + 2)) {
|
||||
tcp_callback_writer(c);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -254,11 +254,16 @@ struct comm_point {
|
|||
int tcp_write_and_read;
|
||||
|
||||
/** byte count for written length over write channel, for when
|
||||
* tcp_write_and_read is enabled */
|
||||
* tcp_write_and_read is enabled. When tcp_write_and_read is enabled,
|
||||
* this is the counter for writing, the one for reading is in the
|
||||
* commpoint.buffer sldns buffer. The counter counts from 0 to
|
||||
* 2+tcp_write_pkt_len, and includes the tcp length bytes. */
|
||||
size_t tcp_write_byte_count;
|
||||
|
||||
/** packet to write currently over the write channel. for when
|
||||
* tcp_write_and_read is enabled */
|
||||
* tcp_write_and_read is enabled. When tcp_write_and_read is enabled,
|
||||
* this is the buffer for the written packet, the commpoint.buffer
|
||||
* sldns buffer is the buffer for the received packet. */
|
||||
uint8_t* tcp_write_pkt;
|
||||
/** length of tcp_write_pkt in bytes */
|
||||
size_t tcp_write_pkt_len;
|
||||
|
|
|
|||
Loading…
Reference in a new issue