aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--contrib/Socket.cpp120
-rw-r--r--contrib/Socket.h1
2 files changed, 120 insertions, 1 deletions
diff --git a/contrib/Socket.cpp b/contrib/Socket.cpp
index 7ff6b5e..d12c970 100644
--- a/contrib/Socket.cpp
+++ b/contrib/Socket.cpp
@@ -409,6 +409,121 @@ bool TCPSocket::valid() const
return m_sock != -1;
}
+void TCPSocket::connect(const std::string& hostname, int port, int timeout_ms)
+{
+ if (m_sock != INVALID_SOCKET) {
+ throw std::logic_error("You may only connect an invalid TCPSocket");
+ }
+
+ char service[NI_MAXSERV];
+ snprintf(service, NI_MAXSERV-1, "%d", port);
+
+ /* Obtain address(es) matching host/port */
+ struct addrinfo hints;
+ memset(&hints, 0, sizeof(struct addrinfo));
+ hints.ai_family = AF_INET;
+ hints.ai_socktype = SOCK_STREAM;
+ hints.ai_flags = 0;
+ hints.ai_protocol = 0;
+
+ struct addrinfo *result, *rp;
+ int s = getaddrinfo(hostname.c_str(), service, &hints, &result);
+ if (s != 0) {
+ throw runtime_error(string("getaddrinfo failed: ") + gai_strerror(s));
+ }
+
+ int flags = 0;
+
+ /* getaddrinfo() returns a list of address structures.
+ Try each address until we successfully connect(2).
+ If socket(2) (or connect(2)) fails, we (close the socket
+ and) try the next address. */
+
+ for (rp = result; rp != nullptr; rp = rp->ai_next) {
+ int sfd = ::socket(rp->ai_family, rp->ai_socktype,
+ rp->ai_protocol);
+ if (sfd == -1)
+ continue;
+
+ flags = fcntl(sfd, F_GETFL);
+ if (flags == -1) {
+ std::string errstr(strerror(errno));
+ throw std::runtime_error("TCP: Could not get socket flags: " + errstr);
+ }
+
+ if (fcntl(sfd, F_SETFL, flags | O_NONBLOCK) == -1) {
+ std::string errstr(strerror(errno));
+ throw std::runtime_error("TCP: Could not set O_NONBLOCK: " + errstr);
+ }
+
+ int ret = ::connect(sfd, rp->ai_addr, rp->ai_addrlen);
+ if (ret == 0) {
+ m_sock = sfd;
+ break;
+ }
+ if (ret == -1 and errno == EINPROGRESS) {
+ m_sock = sfd;
+ struct pollfd fds[1];
+ fds[0].fd = m_sock;
+ fds[0].events = POLLOUT;
+
+ int retval = poll(fds, 1, timeout_ms);
+
+ if (retval == -1) {
+ std::string errstr(strerror(errno));
+ ::close(m_sock);
+ freeaddrinfo(result);
+ throw runtime_error("TCP: connect error on poll: " + errstr);
+ }
+ else if (retval > 0) {
+ int so_error = 0;
+ socklen_t len = sizeof(so_error);
+
+ if (getsockopt(m_sock, SOL_SOCKET, SO_ERROR, &so_error, &len) == -1) {
+ std::string errstr(strerror(errno));
+ ::close(m_sock);
+ freeaddrinfo(result);
+ throw runtime_error("TCP: getsockopt error connect: " + errstr);
+ }
+
+ if (so_error == 0) {
+ break;
+ }
+ }
+ else {
+ ::close(m_sock);
+ freeaddrinfo(result);
+ throw runtime_error("Timeout on connect");
+ }
+ break;
+ }
+
+ ::close(sfd);
+ }
+
+ if (m_sock != INVALID_SOCKET) {
+#if defined(HAVE_SO_NOSIGPIPE)
+ int val = 1;
+ if (setsockopt(m_sock, SOL_SOCKET, SO_NOSIGPIPE, &val, sizeof(val))
+ == SOCKET_ERROR) {
+ throw runtime_error("Can't set SO_NOSIGPIPE");
+ }
+#endif
+ }
+
+ // Don't keep the socket blocking
+ if (fcntl(m_sock, F_SETFL, flags) == -1) {
+ std::string errstr(strerror(errno));
+ throw std::runtime_error("TCP: Could not set O_NONBLOCK: " + errstr);
+ }
+
+ freeaddrinfo(result);
+
+ if (rp == nullptr) {
+ throw runtime_error("Could not connect");
+ }
+}
+
void TCPSocket::connect(const std::string& hostname, int port, bool nonblock)
{
if (m_sock != INVALID_SOCKET) {
@@ -447,11 +562,15 @@ void TCPSocket::connect(const std::string& hostname, int port, bool nonblock)
int flags = fcntl(sfd, F_GETFL);
if (flags == -1) {
std::string errstr(strerror(errno));
+ freeaddrinfo(result);
+ ::close(sfd);
throw std::runtime_error("TCP: Could not get socket flags: " + errstr);
}
if (fcntl(sfd, F_SETFL, flags | O_NONBLOCK) == -1) {
std::string errstr(strerror(errno));
+ freeaddrinfo(result);
+ ::close(sfd);
throw std::runtime_error("TCP: Could not set O_NONBLOCK: " + errstr);
}
}
@@ -480,7 +599,6 @@ void TCPSocket::connect(const std::string& hostname, int port, bool nonblock)
if (rp == nullptr) {
throw runtime_error("Could not connect");
}
-
}
void TCPSocket::listen(int port, const string& name)
diff --git a/contrib/Socket.h b/contrib/Socket.h
index 33cdc05..08607a5 100644
--- a/contrib/Socket.h
+++ b/contrib/Socket.h
@@ -168,6 +168,7 @@ class TCPSocket {
bool valid(void) const;
void connect(const std::string& hostname, int port, bool nonblock = false);
+ void connect(const std::string& hostname, int port, int timeout_ms);
void listen(int port, const std::string& name);
void close(void);