summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--lib/Socket.cpp120
-rw-r--r--lib/Socket.h1
-rw-r--r--lib/edi/common.cpp7
3 files changed, 127 insertions, 1 deletions
diff --git a/lib/Socket.cpp b/lib/Socket.cpp
index 7ff6b5e..d12c970 100644
--- a/lib/Socket.cpp
+++ b/lib/Socket.cpp
@@ -409,6 +409,121 @@ bool TCPSocket::valid() const
return m_sock != -1;
}
+void TCPSocket::connect(const std::string& hostname, int port, int timeout_ms)
+{
+ if (m_sock != INVALID_SOCKET) {
+ throw std::logic_error("You may only connect an invalid TCPSocket");
+ }
+
+ char service[NI_MAXSERV];
+ snprintf(service, NI_MAXSERV-1, "%d", port);
+
+ /* Obtain address(es) matching host/port */
+ struct addrinfo hints;
+ memset(&hints, 0, sizeof(struct addrinfo));
+ hints.ai_family = AF_INET;
+ hints.ai_socktype = SOCK_STREAM;
+ hints.ai_flags = 0;
+ hints.ai_protocol = 0;
+
+ struct addrinfo *result, *rp;
+ int s = getaddrinfo(hostname.c_str(), service, &hints, &result);
+ if (s != 0) {
+ throw runtime_error(string("getaddrinfo failed: ") + gai_strerror(s));
+ }
+
+ int flags = 0;
+
+ /* getaddrinfo() returns a list of address structures.
+ Try each address until we successfully connect(2).
+ If socket(2) (or connect(2)) fails, we (close the socket
+ and) try the next address. */
+
+ for (rp = result; rp != nullptr; rp = rp->ai_next) {
+ int sfd = ::socket(rp->ai_family, rp->ai_socktype,
+ rp->ai_protocol);
+ if (sfd == -1)
+ continue;
+
+ flags = fcntl(sfd, F_GETFL);
+ if (flags == -1) {
+ std::string errstr(strerror(errno));
+ throw std::runtime_error("TCP: Could not get socket flags: " + errstr);
+ }
+
+ if (fcntl(sfd, F_SETFL, flags | O_NONBLOCK) == -1) {
+ std::string errstr(strerror(errno));
+ throw std::runtime_error("TCP: Could not set O_NONBLOCK: " + errstr);
+ }
+
+ int ret = ::connect(sfd, rp->ai_addr, rp->ai_addrlen);
+ if (ret == 0) {
+ m_sock = sfd;
+ break;
+ }
+ if (ret == -1 and errno == EINPROGRESS) {
+ m_sock = sfd;
+ struct pollfd fds[1];
+ fds[0].fd = m_sock;
+ fds[0].events = POLLOUT;
+
+ int retval = poll(fds, 1, timeout_ms);
+
+ if (retval == -1) {
+ std::string errstr(strerror(errno));
+ ::close(m_sock);
+ freeaddrinfo(result);
+ throw runtime_error("TCP: connect error on poll: " + errstr);
+ }
+ else if (retval > 0) {
+ int so_error = 0;
+ socklen_t len = sizeof(so_error);
+
+ if (getsockopt(m_sock, SOL_SOCKET, SO_ERROR, &so_error, &len) == -1) {
+ std::string errstr(strerror(errno));
+ ::close(m_sock);
+ freeaddrinfo(result);
+ throw runtime_error("TCP: getsockopt error connect: " + errstr);
+ }
+
+ if (so_error == 0) {
+ break;
+ }
+ }
+ else {
+ ::close(m_sock);
+ freeaddrinfo(result);
+ throw runtime_error("Timeout on connect");
+ }
+ break;
+ }
+
+ ::close(sfd);
+ }
+
+ if (m_sock != INVALID_SOCKET) {
+#if defined(HAVE_SO_NOSIGPIPE)
+ int val = 1;
+ if (setsockopt(m_sock, SOL_SOCKET, SO_NOSIGPIPE, &val, sizeof(val))
+ == SOCKET_ERROR) {
+ throw runtime_error("Can't set SO_NOSIGPIPE");
+ }
+#endif
+ }
+
+ // Don't keep the socket blocking
+ if (fcntl(m_sock, F_SETFL, flags) == -1) {
+ std::string errstr(strerror(errno));
+ throw std::runtime_error("TCP: Could not set O_NONBLOCK: " + errstr);
+ }
+
+ freeaddrinfo(result);
+
+ if (rp == nullptr) {
+ throw runtime_error("Could not connect");
+ }
+}
+
void TCPSocket::connect(const std::string& hostname, int port, bool nonblock)
{
if (m_sock != INVALID_SOCKET) {
@@ -447,11 +562,15 @@ void TCPSocket::connect(const std::string& hostname, int port, bool nonblock)
int flags = fcntl(sfd, F_GETFL);
if (flags == -1) {
std::string errstr(strerror(errno));
+ freeaddrinfo(result);
+ ::close(sfd);
throw std::runtime_error("TCP: Could not get socket flags: " + errstr);
}
if (fcntl(sfd, F_SETFL, flags | O_NONBLOCK) == -1) {
std::string errstr(strerror(errno));
+ freeaddrinfo(result);
+ ::close(sfd);
throw std::runtime_error("TCP: Could not set O_NONBLOCK: " + errstr);
}
}
@@ -480,7 +599,6 @@ void TCPSocket::connect(const std::string& hostname, int port, bool nonblock)
if (rp == nullptr) {
throw runtime_error("Could not connect");
}
-
}
void TCPSocket::listen(int port, const string& name)
diff --git a/lib/Socket.h b/lib/Socket.h
index 33cdc05..08607a5 100644
--- a/lib/Socket.h
+++ b/lib/Socket.h
@@ -168,6 +168,7 @@ class TCPSocket {
bool valid(void) const;
void connect(const std::string& hostname, int port, bool nonblock = false);
+ void connect(const std::string& hostname, int port, int timeout_ms);
void listen(int port, const std::string& name);
void close(void);
diff --git a/lib/edi/common.cpp b/lib/edi/common.cpp
index abaf2ed..c99997a 100644
--- a/lib/edi/common.cpp
+++ b/lib/edi/common.cpp
@@ -154,6 +154,7 @@ void TagDispatcher::push_bytes(const vector<uint8_t> &buf)
while (m_input_data.size() > 2) {
if (m_input_data[0] == 'A' and m_input_data[1] == 'F') {
const auto r = decode_afpacket(m_input_data);
+ bool leave_loop = false;
switch (r.st) {
case decode_state_e::Ok:
m_last_sequences.pseq_valid = false;
@@ -161,9 +162,11 @@ void TagDispatcher::push_bytes(const vector<uint8_t> &buf)
break;
case decode_state_e::MissingData:
/* Continue filling buffer */
+ leave_loop = true;
break;
case decode_state_e::Error:
m_last_sequences.pseq_valid = false;
+ leave_loop = true;
break;
}
@@ -174,6 +177,10 @@ void TagDispatcher::push_bytes(const vector<uint8_t> &buf)
back_inserter(remaining_data));
m_input_data = remaining_data;
}
+
+ if (leave_loop) {
+ break;
+ }
}
else if (m_input_data[0] == 'P' and m_input_data[1] == 'F') {
PFT::Fragment fragment;