diff options
-rw-r--r-- | lib/Socket.cpp | 120 | ||||
-rw-r--r-- | lib/Socket.h | 1 | ||||
-rw-r--r-- | lib/edi/common.cpp | 7 |
3 files changed, 127 insertions, 1 deletions
diff --git a/lib/Socket.cpp b/lib/Socket.cpp index 7ff6b5e..d12c970 100644 --- a/lib/Socket.cpp +++ b/lib/Socket.cpp @@ -409,6 +409,121 @@ bool TCPSocket::valid() const return m_sock != -1; } +void TCPSocket::connect(const std::string& hostname, int port, int timeout_ms) +{ + if (m_sock != INVALID_SOCKET) { + throw std::logic_error("You may only connect an invalid TCPSocket"); + } + + char service[NI_MAXSERV]; + snprintf(service, NI_MAXSERV-1, "%d", port); + + /* Obtain address(es) matching host/port */ + struct addrinfo hints; + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_INET; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = 0; + hints.ai_protocol = 0; + + struct addrinfo *result, *rp; + int s = getaddrinfo(hostname.c_str(), service, &hints, &result); + if (s != 0) { + throw runtime_error(string("getaddrinfo failed: ") + gai_strerror(s)); + } + + int flags = 0; + + /* getaddrinfo() returns a list of address structures. + Try each address until we successfully connect(2). + If socket(2) (or connect(2)) fails, we (close the socket + and) try the next address. */ + + for (rp = result; rp != nullptr; rp = rp->ai_next) { + int sfd = ::socket(rp->ai_family, rp->ai_socktype, + rp->ai_protocol); + if (sfd == -1) + continue; + + flags = fcntl(sfd, F_GETFL); + if (flags == -1) { + std::string errstr(strerror(errno)); + throw std::runtime_error("TCP: Could not get socket flags: " + errstr); + } + + if (fcntl(sfd, F_SETFL, flags | O_NONBLOCK) == -1) { + std::string errstr(strerror(errno)); + throw std::runtime_error("TCP: Could not set O_NONBLOCK: " + errstr); + } + + int ret = ::connect(sfd, rp->ai_addr, rp->ai_addrlen); + if (ret == 0) { + m_sock = sfd; + break; + } + if (ret == -1 and errno == EINPROGRESS) { + m_sock = sfd; + struct pollfd fds[1]; + fds[0].fd = m_sock; + fds[0].events = POLLOUT; + + int retval = poll(fds, 1, timeout_ms); + + if (retval == -1) { + std::string errstr(strerror(errno)); + ::close(m_sock); + freeaddrinfo(result); + throw runtime_error("TCP: connect error on poll: " + errstr); + } + else if (retval > 0) { + int so_error = 0; + socklen_t len = sizeof(so_error); + + if (getsockopt(m_sock, SOL_SOCKET, SO_ERROR, &so_error, &len) == -1) { + std::string errstr(strerror(errno)); + ::close(m_sock); + freeaddrinfo(result); + throw runtime_error("TCP: getsockopt error connect: " + errstr); + } + + if (so_error == 0) { + break; + } + } + else { + ::close(m_sock); + freeaddrinfo(result); + throw runtime_error("Timeout on connect"); + } + break; + } + + ::close(sfd); + } + + if (m_sock != INVALID_SOCKET) { +#if defined(HAVE_SO_NOSIGPIPE) + int val = 1; + if (setsockopt(m_sock, SOL_SOCKET, SO_NOSIGPIPE, &val, sizeof(val)) + == SOCKET_ERROR) { + throw runtime_error("Can't set SO_NOSIGPIPE"); + } +#endif + } + + // Don't keep the socket blocking + if (fcntl(m_sock, F_SETFL, flags) == -1) { + std::string errstr(strerror(errno)); + throw std::runtime_error("TCP: Could not set O_NONBLOCK: " + errstr); + } + + freeaddrinfo(result); + + if (rp == nullptr) { + throw runtime_error("Could not connect"); + } +} + void TCPSocket::connect(const std::string& hostname, int port, bool nonblock) { if (m_sock != INVALID_SOCKET) { @@ -447,11 +562,15 @@ void TCPSocket::connect(const std::string& hostname, int port, bool nonblock) int flags = fcntl(sfd, F_GETFL); if (flags == -1) { std::string errstr(strerror(errno)); + freeaddrinfo(result); + ::close(sfd); throw std::runtime_error("TCP: Could not get socket flags: " + errstr); } if (fcntl(sfd, F_SETFL, flags | O_NONBLOCK) == -1) { std::string errstr(strerror(errno)); + freeaddrinfo(result); + ::close(sfd); throw std::runtime_error("TCP: Could not set O_NONBLOCK: " + errstr); } } @@ -480,7 +599,6 @@ void TCPSocket::connect(const std::string& hostname, int port, bool nonblock) if (rp == nullptr) { throw runtime_error("Could not connect"); } - } void TCPSocket::listen(int port, const string& name) diff --git a/lib/Socket.h b/lib/Socket.h index 33cdc05..08607a5 100644 --- a/lib/Socket.h +++ b/lib/Socket.h @@ -168,6 +168,7 @@ class TCPSocket { bool valid(void) const; void connect(const std::string& hostname, int port, bool nonblock = false); + void connect(const std::string& hostname, int port, int timeout_ms); void listen(int port, const std::string& name); void close(void); diff --git a/lib/edi/common.cpp b/lib/edi/common.cpp index abaf2ed..c99997a 100644 --- a/lib/edi/common.cpp +++ b/lib/edi/common.cpp @@ -154,6 +154,7 @@ void TagDispatcher::push_bytes(const vector<uint8_t> &buf) while (m_input_data.size() > 2) { if (m_input_data[0] == 'A' and m_input_data[1] == 'F') { const auto r = decode_afpacket(m_input_data); + bool leave_loop = false; switch (r.st) { case decode_state_e::Ok: m_last_sequences.pseq_valid = false; @@ -161,9 +162,11 @@ void TagDispatcher::push_bytes(const vector<uint8_t> &buf) break; case decode_state_e::MissingData: /* Continue filling buffer */ + leave_loop = true; break; case decode_state_e::Error: m_last_sequences.pseq_valid = false; + leave_loop = true; break; } @@ -174,6 +177,10 @@ void TagDispatcher::push_bytes(const vector<uint8_t> &buf) back_inserter(remaining_data)); m_input_data = remaining_data; } + + if (leave_loop) { + break; + } } else if (m_input_data[0] == 'P' and m_input_data[1] == 'F') { PFT::Fragment fragment; |