aboutsummaryrefslogtreecommitdiffstats
path: root/lib/Socket.cpp
diff options
context:
space:
mode:
authorshunt010 <sam@maxxwave.co.uk>2025-12-31 13:19:28 +0000
committerGitHub <noreply@github.com>2025-12-31 13:19:28 +0000
commitf8eaf51f61cdae65e90675920e427d23b8da7027 (patch)
treee70c92214ac05cf0a2001e0481e289343c094d65 /lib/Socket.cpp
parentf8b5402727b7e94aecbfb663a601577f97bae5b9 (diff)
parenta5f80a99e0dad51c45e8511347f27d816ae92e20 (diff)
downloaddabmux-f8eaf51f61cdae65e90675920e427d23b8da7027.tar.gz
dabmux-f8eaf51f61cdae65e90675920e427d23b8da7027.tar.bz2
dabmux-f8eaf51f61cdae65e90675920e427d23b8da7027.zip
Merge pull request #1 from Opendigitalradio/master
Bring up to date
Diffstat (limited to 'lib/Socket.cpp')
-rw-r--r--lib/Socket.cpp198
1 files changed, 165 insertions, 33 deletions
diff --git a/lib/Socket.cpp b/lib/Socket.cpp
index 10ec1ca..33c9c73 100644
--- a/lib/Socket.cpp
+++ b/lib/Socket.cpp
@@ -24,7 +24,8 @@
#include "Socket.h"
-#include <iostream>
+#include <numeric>
+#include <stdexcept>
#include <cstdio>
#include <cstring>
#include <cerrno>
@@ -106,16 +107,20 @@ UDPSocket::UDPSocket(UDPSocket&& other)
{
m_sock = other.m_sock;
m_port = other.m_port;
+ m_multicast_source = other.m_multicast_source;
other.m_port = 0;
other.m_sock = INVALID_SOCKET;
+ other.m_multicast_source = "";
}
const UDPSocket& UDPSocket::operator=(UDPSocket&& other)
{
m_sock = other.m_sock;
m_port = other.m_port;
+ m_multicast_source = other.m_multicast_source;
other.m_port = 0;
other.m_sock = INVALID_SOCKET;
+ other.m_multicast_source = "";
return *this;
}
@@ -144,6 +149,7 @@ void UDPSocket::reinit(int port, const std::string& name)
// No need to bind to a given port, creating the
// socket is enough
m_sock = ::socket(AF_INET, SOCK_DGRAM, 0);
+ post_init();
return;
}
@@ -180,6 +186,7 @@ void UDPSocket::reinit(int port, const std::string& name)
if (::bind(sfd, rp->ai_addr, rp->ai_addrlen) == 0) {
m_sock = sfd;
+ post_init();
break;
}
@@ -189,10 +196,47 @@ void UDPSocket::reinit(int port, const std::string& name)
freeaddrinfo(result);
if (rp == nullptr) {
- throw runtime_error("Could not bind");
+ throw runtime_error(string{"Could not bind to port "} + to_string(port));
}
}
+void UDPSocket::post_init() {
+ int pktinfo = 1;
+ if (setsockopt(m_sock, IPPROTO_IP, IP_PKTINFO, &pktinfo, sizeof(pktinfo)) == SOCKET_ERROR) {
+ throw runtime_error(string("Can't request pktinfo: ") + strerror(errno));
+ }
+
+}
+
+void UDPSocket::init_receive_multicast(int port, const string& local_if_addr, const string& mcastaddr)
+{
+ if (m_sock != INVALID_SOCKET) {
+ ::close(m_sock);
+ }
+
+ m_port = port;
+ m_sock = ::socket(AF_INET, SOCK_DGRAM, 0);
+ post_init();
+
+ int reuse_setting = 1;
+ if (setsockopt(m_sock, SOL_SOCKET, SO_REUSEADDR, &reuse_setting, sizeof(reuse_setting)) == SOCKET_ERROR) {
+ throw runtime_error("Can't reuse address");
+ }
+
+ struct sockaddr_in la;
+ memset((char *) &la, 0, sizeof(la));
+ la.sin_family = AF_INET;
+ la.sin_port = htons(port);
+ la.sin_addr.s_addr = INADDR_ANY;
+ if (::bind(m_sock, (struct sockaddr*)&la, sizeof(la))) {
+ throw runtime_error(string("Could not bind: ") + strerror(errno));
+ }
+
+ m_multicast_source = mcastaddr;
+ join_group(mcastaddr.c_str(), local_if_addr.c_str());
+}
+
+
void UDPSocket::close()
{
if (m_sock != INVALID_SOCKET) {
@@ -212,16 +256,26 @@ UDPSocket::~UDPSocket()
UDPPacket UDPSocket::receive(size_t max_size)
{
+ struct sockaddr_in addr;
+ struct msghdr msg;
+ struct iovec iov;
+ constexpr size_t BUFFER_SIZE = 1024;
+ char control_buffer[BUFFER_SIZE];
+ struct cmsghdr *cmsg;
+
UDPPacket packet(max_size);
- socklen_t addrSize;
- addrSize = sizeof(*packet.address.as_sockaddr());
- ssize_t ret = recvfrom(m_sock,
- packet.buffer.data(),
- packet.buffer.size(),
- 0,
- packet.address.as_sockaddr(),
- &addrSize);
+ memset(&msg, 0, sizeof(msg));
+ msg.msg_name = &addr;
+ msg.msg_namelen = sizeof(addr);
+ msg.msg_iov = &iov;
+ iov.iov_base = packet.buffer.data();
+ iov.iov_len = packet.buffer.size();
+ msg.msg_iovlen = 1;
+ msg.msg_control = control_buffer;
+ msg.msg_controllen = sizeof(control_buffer);
+
+ ssize_t ret = recvmsg(m_sock, &msg, 0);
if (ret == SOCKET_ERROR) {
packet.buffer.resize(0);
@@ -232,12 +286,42 @@ UDPPacket UDPSocket::receive(size_t max_size)
if (errno == EAGAIN or errno == EWOULDBLOCK)
#endif
{
- return 0;
+ return packet;
}
throw runtime_error(string("Can't receive data: ") + strerror(errno));
}
- packet.buffer.resize(ret);
+ struct in_pktinfo *pktinfo = nullptr;
+ for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+ if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_PKTINFO) {
+ pktinfo = (struct in_pktinfo *)CMSG_DATA(cmsg);
+ break;
+ }
+ }
+
+ if (pktinfo) {
+ char src_addr[INET_ADDRSTRLEN];
+ char dst_addr[INET_ADDRSTRLEN];
+ inet_ntop(AF_INET, &(addr.sin_addr), src_addr, INET_ADDRSTRLEN);
+ inet_ntop(AF_INET, &(pktinfo->ipi_addr), dst_addr, INET_ADDRSTRLEN);
+ //fprintf(stderr, "Received packet from %s to %s: %zu\n", src_addr, dst_addr, ret);
+
+ memcpy(&packet.address.addr, &addr, sizeof(addr));
+
+ if (m_multicast_source.empty() or
+ strcmp(dst_addr, m_multicast_source.c_str()) == 0) {
+ packet.buffer.resize(ret);
+ }
+ else {
+ // Ignore packet for different multicast group
+ packet.buffer.resize(0);
+ }
+ }
+ else {
+ //fprintf(stderr, "No pktinfo: %zu\n", ret);
+ packet.buffer.resize(ret);
+ }
+
return packet;
}
@@ -269,14 +353,14 @@ void UDPSocket::send(const std::string& data, InetAddress destination)
}
}
-void UDPSocket::joinGroup(const char* groupname, const char* if_addr)
+void UDPSocket::join_group(const char* groupname, const char* if_addr)
{
ip_mreqn group;
if ((group.imr_multiaddr.s_addr = inet_addr(groupname)) == INADDR_NONE) {
throw runtime_error("Cannot convert multicast group name");
}
if (!IN_MULTICAST(ntohl(group.imr_multiaddr.s_addr))) {
- throw runtime_error("Group name is not a multicast address");
+ throw runtime_error(string("Group name '") + groupname + "' is not a multicast address");
}
if (if_addr) {
@@ -288,7 +372,7 @@ void UDPSocket::joinGroup(const char* groupname, const char* if_addr)
group.imr_ifindex = 0;
if (setsockopt(m_sock, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group))
== SOCKET_ERROR) {
- throw runtime_error(string("Can't join multicast group") + strerror(errno));
+ throw runtime_error(string("Can't join multicast group: ") + strerror(errno));
}
}
@@ -296,12 +380,12 @@ void UDPSocket::setMulticastSource(const char* source_addr)
{
struct in_addr addr;
if (inet_aton(source_addr, &addr) == 0) {
- throw runtime_error(string("Can't parse source address") + strerror(errno));
+ throw runtime_error(string("Can't parse source address: ") + strerror(errno));
}
if (setsockopt(m_sock, IPPROTO_IP, IP_MULTICAST_IF, &addr, sizeof(addr))
== SOCKET_ERROR) {
- throw runtime_error(string("Can't set source address") + strerror(errno));
+ throw runtime_error(string("Can't set source address: ") + strerror(errno));
}
}
@@ -309,7 +393,7 @@ void UDPSocket::setMulticastTTL(int ttl)
{
if (setsockopt(m_sock, IPPROTO_IP, IP_MULTICAST_TTL, &ttl, sizeof(ttl))
== SOCKET_ERROR) {
- throw runtime_error(string("Can't set multicast ttl") + strerror(errno));
+ throw runtime_error(string("Can't set multicast ttl: ") + strerror(errno));
}
}
@@ -327,15 +411,13 @@ void UDPReceiver::add_receive_port(int port, const string& bindto, const string&
UDPSocket sock;
if (IN_MULTICAST(ntohl(inet_addr(mcastaddr.c_str())))) {
- sock.reinit(port, mcastaddr);
- sock.setMulticastSource(bindto.c_str());
- sock.joinGroup(mcastaddr.c_str(), bindto.c_str());
+ sock.init_receive_multicast(port, bindto, mcastaddr);
}
else {
sock.reinit(port, bindto);
}
- m_sockets.push_back(move(sock));
+ m_sockets.push_back(std::move(sock));
}
vector<UDPReceiver::ReceivedPacket> UDPReceiver::receive(int timeout_ms)
@@ -366,11 +448,13 @@ vector<UDPReceiver::ReceivedPacket> UDPReceiver::receive(int timeout_ms)
for (size_t i = 0; i < m_sockets.size(); i++) {
if (fds[i].revents & POLLIN) {
auto p = m_sockets[i].receive(2048); // This is larger than the usual MTU
- ReceivedPacket rp;
- rp.packetdata = move(p.buffer);
- rp.received_from = move(p.address);
- rp.port_received_on = m_sockets[i].getPort();
- received.push_back(move(rp));
+ if (not p.buffer.empty()) {
+ ReceivedPacket rp;
+ rp.packetdata = std::move(p.buffer);
+ rp.received_from = std::move(p.address);
+ rp.port_received_on = m_sockets[i].getPort();
+ received.push_back(std::move(rp));
+ }
}
}
@@ -395,7 +479,7 @@ TCPSocket::~TCPSocket()
TCPSocket::TCPSocket(TCPSocket&& other) :
m_sock(other.m_sock),
- m_remote_address(move(other.m_remote_address))
+ m_remote_address(std::move(other.m_remote_address))
{
if (other.m_sock != -1) {
other.m_sock = -1;
@@ -884,22 +968,33 @@ ssize_t TCPClient::recv(void *buffer, size_t length, int flags, int timeout_ms)
reconnect();
}
+ m_last_received_packet_ts = chrono::steady_clock::now();
+
return ret;
}
catch (const TCPSocket::Interrupted&) {
return -1;
}
catch (const TCPSocket::Timeout&) {
+ const auto timeout = chrono::milliseconds(timeout_ms * 5);
+ if (m_last_received_packet_ts.has_value() and
+ chrono::steady_clock::now() - *m_last_received_packet_ts > timeout)
+ {
+ // This is to catch half-closed TCP connections
+ reconnect();
+ }
+
return 0;
}
- return 0;
+ throw std::logic_error("unreachable");
}
void TCPClient::reconnect()
{
TCPSocket newsock;
m_sock = std::move(newsock);
+ m_last_received_packet_ts = nullopt;
m_sock.connect(m_hostname, m_port, true);
}
@@ -907,7 +1002,7 @@ TCPConnection::TCPConnection(TCPSocket&& sock) :
queue(),
m_running(true),
m_sender_thread(),
- m_sock(move(sock))
+ m_sock(std::move(sock))
{
#if MISSING_OWN_ADDR
auto own_addr = m_sock.getOwnAddress();
@@ -969,6 +1064,17 @@ void TCPConnection::process()
#endif
}
+TCPConnection::stats_t TCPConnection::get_stats() const
+{
+ TCPConnection::stats_t s;
+ const vector<size_t> buffer_sizes = queue.map<size_t>(
+ [](const vector<uint8_t>& vec) { return vec.size(); }
+ );
+
+ s.buffer_fullness = std::accumulate(buffer_sizes.cbegin(), buffer_sizes.cend(), 0);
+ s.remote_address = m_sock.get_remote_address();
+ return s;
+}
TCPDataDispatcher::TCPDataDispatcher(size_t max_queue_size, size_t buffers_to_preroll) :
m_max_queue_size(max_queue_size),
@@ -1026,7 +1132,7 @@ void TCPDataDispatcher::process()
auto sock = m_listener_socket.accept(timeout_ms);
if (sock.valid()) {
auto lock = unique_lock<mutex>(m_mutex);
- m_connections.emplace(m_connections.begin(), move(sock));
+ m_connections.emplace(m_connections.begin(), std::move(sock));
if (m_buffers_to_preroll > 0) {
for (const auto& buf : m_preroll_queue) {
@@ -1042,6 +1148,17 @@ void TCPDataDispatcher::process()
}
}
+
+std::vector<TCPConnection::stats_t> TCPDataDispatcher::get_stats() const
+{
+ std::vector<TCPConnection::stats_t> s;
+ auto lock = unique_lock<mutex>(m_mutex);
+ for (const auto& conn : m_connections) {
+ s.push_back(conn.get_stats());
+ }
+ return s;
+}
+
TCPReceiveServer::TCPReceiveServer(size_t blocksize) :
m_blocksize(blocksize)
{
@@ -1098,7 +1215,7 @@ void TCPReceiveServer::process()
}
else {
buf.resize(r);
- m_queue.push(make_shared<TCPReceiveMessageData>(move(buf)));
+ m_queue.push(make_shared<TCPReceiveMessageData>(std::move(buf)));
}
}
catch (const TCPSocket::Interrupted&) {
@@ -1139,7 +1256,7 @@ TCPSendClient::~TCPSendClient()
}
}
-void TCPSendClient::sendall(const std::vector<uint8_t>& buffer)
+TCPSendClient::ErrorStats TCPSendClient::sendall(const std::vector<uint8_t>& buffer)
{
if (not m_running) {
throw runtime_error(m_exception_data);
@@ -1151,6 +1268,17 @@ void TCPSendClient::sendall(const std::vector<uint8_t>& buffer)
vector<uint8_t> discard;
m_queue.try_pop(discard);
}
+
+ TCPSendClient::ErrorStats es;
+ es.num_reconnects = m_num_reconnects.load();
+
+ es.has_seen_new_errors = es.num_reconnects != m_num_reconnects_prev;
+ m_num_reconnects_prev = es.num_reconnects;
+
+ auto lock = unique_lock<mutex>(m_error_mutex);
+ es.last_error = m_last_error;
+
+ return es;
}
void TCPSendClient::process()
@@ -1172,12 +1300,16 @@ void TCPSendClient::process()
}
else {
try {
+ m_num_reconnects.fetch_add(1, std::memory_order_seq_cst);
m_sock.connect(m_hostname, m_port);
m_is_connected = true;
}
catch (const runtime_error& e) {
m_is_connected = false;
this_thread::sleep_for(chrono::seconds(1));
+
+ auto lock = unique_lock<mutex>(m_error_mutex);
+ m_last_error = e.what();
}
}
}