diff options
| author | shunt010 <sam@maxxwave.co.uk> | 2025-12-31 13:19:28 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-12-31 13:19:28 +0000 |
| commit | f8eaf51f61cdae65e90675920e427d23b8da7027 (patch) | |
| tree | e70c92214ac05cf0a2001e0481e289343c094d65 /lib/Socket.cpp | |
| parent | f8b5402727b7e94aecbfb663a601577f97bae5b9 (diff) | |
| parent | a5f80a99e0dad51c45e8511347f27d816ae92e20 (diff) | |
| download | dabmux-f8eaf51f61cdae65e90675920e427d23b8da7027.tar.gz dabmux-f8eaf51f61cdae65e90675920e427d23b8da7027.tar.bz2 dabmux-f8eaf51f61cdae65e90675920e427d23b8da7027.zip | |
Merge pull request #1 from Opendigitalradio/master
Bring up to date
Diffstat (limited to 'lib/Socket.cpp')
| -rw-r--r-- | lib/Socket.cpp | 198 |
1 files changed, 165 insertions, 33 deletions
diff --git a/lib/Socket.cpp b/lib/Socket.cpp index 10ec1ca..33c9c73 100644 --- a/lib/Socket.cpp +++ b/lib/Socket.cpp @@ -24,7 +24,8 @@ #include "Socket.h" -#include <iostream> +#include <numeric> +#include <stdexcept> #include <cstdio> #include <cstring> #include <cerrno> @@ -106,16 +107,20 @@ UDPSocket::UDPSocket(UDPSocket&& other) { m_sock = other.m_sock; m_port = other.m_port; + m_multicast_source = other.m_multicast_source; other.m_port = 0; other.m_sock = INVALID_SOCKET; + other.m_multicast_source = ""; } const UDPSocket& UDPSocket::operator=(UDPSocket&& other) { m_sock = other.m_sock; m_port = other.m_port; + m_multicast_source = other.m_multicast_source; other.m_port = 0; other.m_sock = INVALID_SOCKET; + other.m_multicast_source = ""; return *this; } @@ -144,6 +149,7 @@ void UDPSocket::reinit(int port, const std::string& name) // No need to bind to a given port, creating the // socket is enough m_sock = ::socket(AF_INET, SOCK_DGRAM, 0); + post_init(); return; } @@ -180,6 +186,7 @@ void UDPSocket::reinit(int port, const std::string& name) if (::bind(sfd, rp->ai_addr, rp->ai_addrlen) == 0) { m_sock = sfd; + post_init(); break; } @@ -189,10 +196,47 @@ void UDPSocket::reinit(int port, const std::string& name) freeaddrinfo(result); if (rp == nullptr) { - throw runtime_error("Could not bind"); + throw runtime_error(string{"Could not bind to port "} + to_string(port)); } } +void UDPSocket::post_init() { + int pktinfo = 1; + if (setsockopt(m_sock, IPPROTO_IP, IP_PKTINFO, &pktinfo, sizeof(pktinfo)) == SOCKET_ERROR) { + throw runtime_error(string("Can't request pktinfo: ") + strerror(errno)); + } + +} + +void UDPSocket::init_receive_multicast(int port, const string& local_if_addr, const string& mcastaddr) +{ + if (m_sock != INVALID_SOCKET) { + ::close(m_sock); + } + + m_port = port; + m_sock = ::socket(AF_INET, SOCK_DGRAM, 0); + post_init(); + + int reuse_setting = 1; + if (setsockopt(m_sock, SOL_SOCKET, SO_REUSEADDR, &reuse_setting, sizeof(reuse_setting)) == SOCKET_ERROR) { + throw runtime_error("Can't reuse address"); + } + + struct sockaddr_in la; + memset((char *) &la, 0, sizeof(la)); + la.sin_family = AF_INET; + la.sin_port = htons(port); + la.sin_addr.s_addr = INADDR_ANY; + if (::bind(m_sock, (struct sockaddr*)&la, sizeof(la))) { + throw runtime_error(string("Could not bind: ") + strerror(errno)); + } + + m_multicast_source = mcastaddr; + join_group(mcastaddr.c_str(), local_if_addr.c_str()); +} + + void UDPSocket::close() { if (m_sock != INVALID_SOCKET) { @@ -212,16 +256,26 @@ UDPSocket::~UDPSocket() UDPPacket UDPSocket::receive(size_t max_size) { + struct sockaddr_in addr; + struct msghdr msg; + struct iovec iov; + constexpr size_t BUFFER_SIZE = 1024; + char control_buffer[BUFFER_SIZE]; + struct cmsghdr *cmsg; + UDPPacket packet(max_size); - socklen_t addrSize; - addrSize = sizeof(*packet.address.as_sockaddr()); - ssize_t ret = recvfrom(m_sock, - packet.buffer.data(), - packet.buffer.size(), - 0, - packet.address.as_sockaddr(), - &addrSize); + memset(&msg, 0, sizeof(msg)); + msg.msg_name = &addr; + msg.msg_namelen = sizeof(addr); + msg.msg_iov = &iov; + iov.iov_base = packet.buffer.data(); + iov.iov_len = packet.buffer.size(); + msg.msg_iovlen = 1; + msg.msg_control = control_buffer; + msg.msg_controllen = sizeof(control_buffer); + + ssize_t ret = recvmsg(m_sock, &msg, 0); if (ret == SOCKET_ERROR) { packet.buffer.resize(0); @@ -232,12 +286,42 @@ UDPPacket UDPSocket::receive(size_t max_size) if (errno == EAGAIN or errno == EWOULDBLOCK) #endif { - return 0; + return packet; } throw runtime_error(string("Can't receive data: ") + strerror(errno)); } - packet.buffer.resize(ret); + struct in_pktinfo *pktinfo = nullptr; + for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; cmsg = CMSG_NXTHDR(&msg, cmsg)) { + if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_PKTINFO) { + pktinfo = (struct in_pktinfo *)CMSG_DATA(cmsg); + break; + } + } + + if (pktinfo) { + char src_addr[INET_ADDRSTRLEN]; + char dst_addr[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, &(addr.sin_addr), src_addr, INET_ADDRSTRLEN); + inet_ntop(AF_INET, &(pktinfo->ipi_addr), dst_addr, INET_ADDRSTRLEN); + //fprintf(stderr, "Received packet from %s to %s: %zu\n", src_addr, dst_addr, ret); + + memcpy(&packet.address.addr, &addr, sizeof(addr)); + + if (m_multicast_source.empty() or + strcmp(dst_addr, m_multicast_source.c_str()) == 0) { + packet.buffer.resize(ret); + } + else { + // Ignore packet for different multicast group + packet.buffer.resize(0); + } + } + else { + //fprintf(stderr, "No pktinfo: %zu\n", ret); + packet.buffer.resize(ret); + } + return packet; } @@ -269,14 +353,14 @@ void UDPSocket::send(const std::string& data, InetAddress destination) } } -void UDPSocket::joinGroup(const char* groupname, const char* if_addr) +void UDPSocket::join_group(const char* groupname, const char* if_addr) { ip_mreqn group; if ((group.imr_multiaddr.s_addr = inet_addr(groupname)) == INADDR_NONE) { throw runtime_error("Cannot convert multicast group name"); } if (!IN_MULTICAST(ntohl(group.imr_multiaddr.s_addr))) { - throw runtime_error("Group name is not a multicast address"); + throw runtime_error(string("Group name '") + groupname + "' is not a multicast address"); } if (if_addr) { @@ -288,7 +372,7 @@ void UDPSocket::joinGroup(const char* groupname, const char* if_addr) group.imr_ifindex = 0; if (setsockopt(m_sock, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)) == SOCKET_ERROR) { - throw runtime_error(string("Can't join multicast group") + strerror(errno)); + throw runtime_error(string("Can't join multicast group: ") + strerror(errno)); } } @@ -296,12 +380,12 @@ void UDPSocket::setMulticastSource(const char* source_addr) { struct in_addr addr; if (inet_aton(source_addr, &addr) == 0) { - throw runtime_error(string("Can't parse source address") + strerror(errno)); + throw runtime_error(string("Can't parse source address: ") + strerror(errno)); } if (setsockopt(m_sock, IPPROTO_IP, IP_MULTICAST_IF, &addr, sizeof(addr)) == SOCKET_ERROR) { - throw runtime_error(string("Can't set source address") + strerror(errno)); + throw runtime_error(string("Can't set source address: ") + strerror(errno)); } } @@ -309,7 +393,7 @@ void UDPSocket::setMulticastTTL(int ttl) { if (setsockopt(m_sock, IPPROTO_IP, IP_MULTICAST_TTL, &ttl, sizeof(ttl)) == SOCKET_ERROR) { - throw runtime_error(string("Can't set multicast ttl") + strerror(errno)); + throw runtime_error(string("Can't set multicast ttl: ") + strerror(errno)); } } @@ -327,15 +411,13 @@ void UDPReceiver::add_receive_port(int port, const string& bindto, const string& UDPSocket sock; if (IN_MULTICAST(ntohl(inet_addr(mcastaddr.c_str())))) { - sock.reinit(port, mcastaddr); - sock.setMulticastSource(bindto.c_str()); - sock.joinGroup(mcastaddr.c_str(), bindto.c_str()); + sock.init_receive_multicast(port, bindto, mcastaddr); } else { sock.reinit(port, bindto); } - m_sockets.push_back(move(sock)); + m_sockets.push_back(std::move(sock)); } vector<UDPReceiver::ReceivedPacket> UDPReceiver::receive(int timeout_ms) @@ -366,11 +448,13 @@ vector<UDPReceiver::ReceivedPacket> UDPReceiver::receive(int timeout_ms) for (size_t i = 0; i < m_sockets.size(); i++) { if (fds[i].revents & POLLIN) { auto p = m_sockets[i].receive(2048); // This is larger than the usual MTU - ReceivedPacket rp; - rp.packetdata = move(p.buffer); - rp.received_from = move(p.address); - rp.port_received_on = m_sockets[i].getPort(); - received.push_back(move(rp)); + if (not p.buffer.empty()) { + ReceivedPacket rp; + rp.packetdata = std::move(p.buffer); + rp.received_from = std::move(p.address); + rp.port_received_on = m_sockets[i].getPort(); + received.push_back(std::move(rp)); + } } } @@ -395,7 +479,7 @@ TCPSocket::~TCPSocket() TCPSocket::TCPSocket(TCPSocket&& other) : m_sock(other.m_sock), - m_remote_address(move(other.m_remote_address)) + m_remote_address(std::move(other.m_remote_address)) { if (other.m_sock != -1) { other.m_sock = -1; @@ -884,22 +968,33 @@ ssize_t TCPClient::recv(void *buffer, size_t length, int flags, int timeout_ms) reconnect(); } + m_last_received_packet_ts = chrono::steady_clock::now(); + return ret; } catch (const TCPSocket::Interrupted&) { return -1; } catch (const TCPSocket::Timeout&) { + const auto timeout = chrono::milliseconds(timeout_ms * 5); + if (m_last_received_packet_ts.has_value() and + chrono::steady_clock::now() - *m_last_received_packet_ts > timeout) + { + // This is to catch half-closed TCP connections + reconnect(); + } + return 0; } - return 0; + throw std::logic_error("unreachable"); } void TCPClient::reconnect() { TCPSocket newsock; m_sock = std::move(newsock); + m_last_received_packet_ts = nullopt; m_sock.connect(m_hostname, m_port, true); } @@ -907,7 +1002,7 @@ TCPConnection::TCPConnection(TCPSocket&& sock) : queue(), m_running(true), m_sender_thread(), - m_sock(move(sock)) + m_sock(std::move(sock)) { #if MISSING_OWN_ADDR auto own_addr = m_sock.getOwnAddress(); @@ -969,6 +1064,17 @@ void TCPConnection::process() #endif } +TCPConnection::stats_t TCPConnection::get_stats() const +{ + TCPConnection::stats_t s; + const vector<size_t> buffer_sizes = queue.map<size_t>( + [](const vector<uint8_t>& vec) { return vec.size(); } + ); + + s.buffer_fullness = std::accumulate(buffer_sizes.cbegin(), buffer_sizes.cend(), 0); + s.remote_address = m_sock.get_remote_address(); + return s; +} TCPDataDispatcher::TCPDataDispatcher(size_t max_queue_size, size_t buffers_to_preroll) : m_max_queue_size(max_queue_size), @@ -1026,7 +1132,7 @@ void TCPDataDispatcher::process() auto sock = m_listener_socket.accept(timeout_ms); if (sock.valid()) { auto lock = unique_lock<mutex>(m_mutex); - m_connections.emplace(m_connections.begin(), move(sock)); + m_connections.emplace(m_connections.begin(), std::move(sock)); if (m_buffers_to_preroll > 0) { for (const auto& buf : m_preroll_queue) { @@ -1042,6 +1148,17 @@ void TCPDataDispatcher::process() } } + +std::vector<TCPConnection::stats_t> TCPDataDispatcher::get_stats() const +{ + std::vector<TCPConnection::stats_t> s; + auto lock = unique_lock<mutex>(m_mutex); + for (const auto& conn : m_connections) { + s.push_back(conn.get_stats()); + } + return s; +} + TCPReceiveServer::TCPReceiveServer(size_t blocksize) : m_blocksize(blocksize) { @@ -1098,7 +1215,7 @@ void TCPReceiveServer::process() } else { buf.resize(r); - m_queue.push(make_shared<TCPReceiveMessageData>(move(buf))); + m_queue.push(make_shared<TCPReceiveMessageData>(std::move(buf))); } } catch (const TCPSocket::Interrupted&) { @@ -1139,7 +1256,7 @@ TCPSendClient::~TCPSendClient() } } -void TCPSendClient::sendall(const std::vector<uint8_t>& buffer) +TCPSendClient::ErrorStats TCPSendClient::sendall(const std::vector<uint8_t>& buffer) { if (not m_running) { throw runtime_error(m_exception_data); @@ -1151,6 +1268,17 @@ void TCPSendClient::sendall(const std::vector<uint8_t>& buffer) vector<uint8_t> discard; m_queue.try_pop(discard); } + + TCPSendClient::ErrorStats es; + es.num_reconnects = m_num_reconnects.load(); + + es.has_seen_new_errors = es.num_reconnects != m_num_reconnects_prev; + m_num_reconnects_prev = es.num_reconnects; + + auto lock = unique_lock<mutex>(m_error_mutex); + es.last_error = m_last_error; + + return es; } void TCPSendClient::process() @@ -1172,12 +1300,16 @@ void TCPSendClient::process() } else { try { + m_num_reconnects.fetch_add(1, std::memory_order_seq_cst); m_sock.connect(m_hostname, m_port); m_is_connected = true; } catch (const runtime_error& e) { m_is_connected = false; this_thread::sleep_for(chrono::seconds(1)); + + auto lock = unique_lock<mutex>(m_error_mutex); + m_last_error = e.what(); } } } |
