aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/Socket.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/Socket.cpp')
-rw-r--r--contrib/Socket.cpp103
1 files changed, 80 insertions, 23 deletions
diff --git a/contrib/Socket.cpp b/contrib/Socket.cpp
index 1281066..bcffb07 100644
--- a/contrib/Socket.cpp
+++ b/contrib/Socket.cpp
@@ -24,7 +24,7 @@
#include "Socket.h"
-#include <iostream>
+#include <stdexcept>
#include <cstdio>
#include <cstring>
#include <cerrno>
@@ -106,16 +106,20 @@ UDPSocket::UDPSocket(UDPSocket&& other)
{
m_sock = other.m_sock;
m_port = other.m_port;
+ m_multicast_source = other.m_multicast_source;
other.m_port = 0;
other.m_sock = INVALID_SOCKET;
+ other.m_multicast_source = "";
}
const UDPSocket& UDPSocket::operator=(UDPSocket&& other)
{
m_sock = other.m_sock;
m_port = other.m_port;
+ m_multicast_source = other.m_multicast_source;
other.m_port = 0;
other.m_sock = INVALID_SOCKET;
+ other.m_multicast_source = "";
return *this;
}
@@ -144,6 +148,7 @@ void UDPSocket::reinit(int port, const std::string& name)
// No need to bind to a given port, creating the
// socket is enough
m_sock = ::socket(AF_INET, SOCK_DGRAM, 0);
+ post_init();
return;
}
@@ -180,6 +185,7 @@ void UDPSocket::reinit(int port, const std::string& name)
if (::bind(sfd, rp->ai_addr, rp->ai_addrlen) == 0) {
m_sock = sfd;
+ post_init();
break;
}
@@ -193,6 +199,14 @@ void UDPSocket::reinit(int port, const std::string& name)
}
}
+void UDPSocket::post_init() {
+ int pktinfo = 1;
+ if (setsockopt(m_sock, IPPROTO_IP, IP_PKTINFO, &pktinfo, sizeof(pktinfo)) == SOCKET_ERROR) {
+ throw runtime_error(string("Can't request pktinfo: ") + strerror(errno));
+ }
+
+}
+
void UDPSocket::init_receive_multicast(int port, const string& local_if_addr, const string& mcastaddr)
{
if (m_sock != INVALID_SOCKET) {
@@ -201,9 +215,10 @@ void UDPSocket::init_receive_multicast(int port, const string& local_if_addr, co
m_port = port;
m_sock = ::socket(AF_INET, SOCK_DGRAM, 0);
+ post_init();
int reuse_setting = 1;
- if (setsockopt(m_sock, SOL_SOCKET, SO_REUSEADDR, &reuse_setting, sizeof(reuse_setting)) == -1) {
+ if (setsockopt(m_sock, SOL_SOCKET, SO_REUSEADDR, &reuse_setting, sizeof(reuse_setting)) == SOCKET_ERROR) {
throw runtime_error("Can't reuse address");
}
@@ -216,8 +231,8 @@ void UDPSocket::init_receive_multicast(int port, const string& local_if_addr, co
throw runtime_error(string("Could not bind: ") + strerror(errno));
}
- joinGroup(mcastaddr.c_str(), local_if_addr.c_str());
-
+ m_multicast_source = mcastaddr;
+ join_group(mcastaddr.c_str(), local_if_addr.c_str());
}
@@ -240,16 +255,26 @@ UDPSocket::~UDPSocket()
UDPPacket UDPSocket::receive(size_t max_size)
{
+ struct sockaddr_in addr;
+ struct msghdr msg;
+ struct iovec iov;
+ constexpr size_t BUFFER_SIZE = 1024;
+ char control_buffer[BUFFER_SIZE];
+ struct cmsghdr *cmsg;
+
UDPPacket packet(max_size);
- socklen_t addrSize;
- addrSize = sizeof(*packet.address.as_sockaddr());
- ssize_t ret = recvfrom(m_sock,
- packet.buffer.data(),
- packet.buffer.size(),
- 0,
- packet.address.as_sockaddr(),
- &addrSize);
+ memset(&msg, 0, sizeof(msg));
+ msg.msg_name = &addr;
+ msg.msg_namelen = sizeof(addr);
+ msg.msg_iov = &iov;
+ iov.iov_base = packet.buffer.data();
+ iov.iov_len = packet.buffer.size();
+ msg.msg_iovlen = 1;
+ msg.msg_control = control_buffer;
+ msg.msg_controllen = sizeof(control_buffer);
+
+ ssize_t ret = recvmsg(m_sock, &msg, 0);
if (ret == SOCKET_ERROR) {
packet.buffer.resize(0);
@@ -260,12 +285,42 @@ UDPPacket UDPSocket::receive(size_t max_size)
if (errno == EAGAIN or errno == EWOULDBLOCK)
#endif
{
- return 0;
+ return packet;
}
throw runtime_error(string("Can't receive data: ") + strerror(errno));
}
- packet.buffer.resize(ret);
+ struct in_pktinfo *pktinfo = nullptr;
+ for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+ if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_PKTINFO) {
+ pktinfo = (struct in_pktinfo *)CMSG_DATA(cmsg);
+ break;
+ }
+ }
+
+ if (pktinfo) {
+ char src_addr[INET_ADDRSTRLEN];
+ char dst_addr[INET_ADDRSTRLEN];
+ inet_ntop(AF_INET, &(addr.sin_addr), src_addr, INET_ADDRSTRLEN);
+ inet_ntop(AF_INET, &(pktinfo->ipi_addr), dst_addr, INET_ADDRSTRLEN);
+ //fprintf(stderr, "Received packet from %s to %s: %zu\n", src_addr, dst_addr, ret);
+
+ memcpy(&packet.address.addr, &addr, sizeof(addr));
+
+ if (m_multicast_source.empty() or
+ strcmp(dst_addr, m_multicast_source.c_str()) == 0) {
+ packet.buffer.resize(ret);
+ }
+ else {
+ // Ignore packet for different multicast group
+ packet.buffer.resize(0);
+ }
+ }
+ else {
+ //fprintf(stderr, "No pktinfo: %zu\n", ret);
+ packet.buffer.resize(ret);
+ }
+
return packet;
}
@@ -297,7 +352,7 @@ void UDPSocket::send(const std::string& data, InetAddress destination)
}
}
-void UDPSocket::joinGroup(const char* groupname, const char* if_addr)
+void UDPSocket::join_group(const char* groupname, const char* if_addr)
{
ip_mreqn group;
if ((group.imr_multiaddr.s_addr = inet_addr(groupname)) == INADDR_NONE) {
@@ -316,7 +371,7 @@ void UDPSocket::joinGroup(const char* groupname, const char* if_addr)
group.imr_ifindex = 0;
if (setsockopt(m_sock, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group))
== SOCKET_ERROR) {
- throw runtime_error(string("Can't join multicast group") + strerror(errno));
+ throw runtime_error(string("Can't join multicast group: ") + strerror(errno));
}
}
@@ -324,12 +379,12 @@ void UDPSocket::setMulticastSource(const char* source_addr)
{
struct in_addr addr;
if (inet_aton(source_addr, &addr) == 0) {
- throw runtime_error(string("Can't parse source address") + strerror(errno));
+ throw runtime_error(string("Can't parse source address: ") + strerror(errno));
}
if (setsockopt(m_sock, IPPROTO_IP, IP_MULTICAST_IF, &addr, sizeof(addr))
== SOCKET_ERROR) {
- throw runtime_error(string("Can't set source address") + strerror(errno));
+ throw runtime_error(string("Can't set source address: ") + strerror(errno));
}
}
@@ -392,11 +447,13 @@ vector<UDPReceiver::ReceivedPacket> UDPReceiver::receive(int timeout_ms)
for (size_t i = 0; i < m_sockets.size(); i++) {
if (fds[i].revents & POLLIN) {
auto p = m_sockets[i].receive(2048); // This is larger than the usual MTU
- ReceivedPacket rp;
- rp.packetdata = move(p.buffer);
- rp.received_from = move(p.address);
- rp.port_received_on = m_sockets[i].getPort();
- received.push_back(move(rp));
+ if (not p.buffer.empty()) {
+ ReceivedPacket rp;
+ rp.packetdata = std::move(p.buffer);
+ rp.received_from = std::move(p.address);
+ rp.port_received_on = m_sockets[i].getPort();
+ received.push_back(std::move(rp));
+ }
}
}