diff options
Diffstat (limited to 'lib/Socket.cpp')
| -rw-r--r-- | lib/Socket.cpp | 817 | 
1 files changed, 817 insertions, 0 deletions
| diff --git a/lib/Socket.cpp b/lib/Socket.cpp new file mode 100644 index 0000000..9b404eb --- /dev/null +++ b/lib/Socket.cpp @@ -0,0 +1,817 @@ +/* +   Copyright (C) 2003, 2004, 2005, 2006, 2007, 2008, 2009 Her Majesty the +   Queen in Right of Canada (Communications Research Center Canada) + +   Copyright (C) 2019 +   Matthias P. Braendli, matthias.braendli@mpb.li + +    http://www.opendigitalradio.org +   */ +/* +   This program is free software: you can redistribute it and/or modify +   it under the terms of the GNU General Public License as published by +   the Free Software Foundation, either version 3 of the License, or +   (at your option) any later version. + +   This program is distributed in the hope that it will be useful, +   but WITHOUT ANY WARRANTY; without even the implied warranty of +   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +   GNU General Public License for more details. + +   You should have received a copy of the GNU General Public License +   along with this program.  If not, see <https://www.gnu.org/licenses/>. +*/ + +#include "Socket.h" + +#include <iostream> +#include <cstdio> +#include <cstring> +#include <cerrno> +#include <fcntl.h> +#include <poll.h> + +namespace Socket { + +using namespace std; + +void InetAddress::resolveUdpDestination(const std::string& destination, int port) +{ +    char service[NI_MAXSERV]; +    snprintf(service, NI_MAXSERV-1, "%d", port); + +    struct addrinfo hints; +    memset(&hints, 0, sizeof(struct addrinfo)); +    hints.ai_family = AF_INET; +    hints.ai_socktype = SOCK_DGRAM; /* Datagram socket */ +    hints.ai_flags = 0; +    hints.ai_protocol = 0; + +    struct addrinfo *result, *rp; +    int s = getaddrinfo(destination.c_str(), service, &hints, &result); +    if (s != 0) { +        throw runtime_error(string("getaddrinfo failed: ") + gai_strerror(s)); +    } + +    for (rp = result; rp != nullptr; rp = rp->ai_next) { +        // Take the first result +        memcpy(&addr, rp->ai_addr, rp->ai_addrlen); +        break; +    } + +    freeaddrinfo(result); + +    if (rp == nullptr) { +        throw runtime_error("Could not resolve"); +    } +} + +UDPPacket::UDPPacket() { } + +UDPPacket::UDPPacket(size_t initSize) : +    buffer(initSize) +{ } + + +UDPSocket::UDPSocket() : +    m_sock(INVALID_SOCKET) +{ +    reinit(0, ""); +} + +UDPSocket::UDPSocket(int port) : +    m_sock(INVALID_SOCKET) +{ +    reinit(port, ""); +} + +UDPSocket::UDPSocket(int port, const std::string& name) : +    m_sock(INVALID_SOCKET) +{ +    reinit(port, name); +} + + +void UDPSocket::setBlocking(bool block) +{ +    int res = fcntl(m_sock, F_SETFL, block ? 0 : O_NONBLOCK); +    if (res == -1) { +        throw runtime_error(string("Can't change blocking state of socket: ") + strerror(errno)); +    } +} + +void UDPSocket::reinit(int port) +{ +    return reinit(port, ""); +} + +void UDPSocket::reinit(int port, const std::string& name) +{ +    if (m_sock != INVALID_SOCKET) { +        ::close(m_sock); +    } + +    if (port == 0) { +        // No need to bind to a given port, creating the +        // socket is enough +        m_sock = ::socket(AF_INET, SOCK_DGRAM, 0); +        return; +    } + +    char service[NI_MAXSERV]; +    snprintf(service, NI_MAXSERV-1, "%d", port); + +    struct addrinfo hints; +    memset(&hints, 0, sizeof(struct addrinfo)); +    hints.ai_family = AF_INET; +    hints.ai_socktype = SOCK_DGRAM; /* Datagram socket */ +    hints.ai_flags = AI_PASSIVE;    /* For wildcard IP address */ +    hints.ai_protocol = 0;          /* Any protocol */ +    hints.ai_canonname = nullptr; +    hints.ai_addr = nullptr; +    hints.ai_next = nullptr; + +    struct addrinfo *result, *rp; +    int s = getaddrinfo(name.empty() ? nullptr : name.c_str(), +            port == 0 ? nullptr : service, +            &hints, &result); +    if (s != 0) { +        throw runtime_error(string("getaddrinfo failed: ") + gai_strerror(s)); +    } + +    /* getaddrinfo() returns a list of address structures. +       Try each address until we successfully bind(2). +       If socket(2) (or bind(2)) fails, we (close the socket +       and) try the next address. */ +    for (rp = result; rp != nullptr; rp = rp->ai_next) { +        int sfd = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); +        if (sfd == -1) { +            continue; +        } + +        if (::bind(sfd, rp->ai_addr, rp->ai_addrlen) == 0) { +            m_sock = sfd; +            break; +        } + +        ::close(sfd); +    } + +    freeaddrinfo(result); + +    if (rp == nullptr) { +        throw runtime_error("Could not bind"); +    } +} + +void UDPSocket::close() +{ +    if (m_sock != INVALID_SOCKET) { +        ::close(m_sock); +    } + +    m_sock = INVALID_SOCKET; +} + +UDPSocket::~UDPSocket() +{ +    if (m_sock != INVALID_SOCKET) { +        ::close(m_sock); +    } +} + + +UDPPacket UDPSocket::receive(size_t max_size) +{ +    UDPPacket packet(max_size); +    socklen_t addrSize; +    addrSize = sizeof(*packet.address.as_sockaddr()); +    ssize_t ret = recvfrom(m_sock, +            packet.buffer.data(), +            packet.buffer.size(), +            0, +            packet.address.as_sockaddr(), +            &addrSize); + +    if (ret == SOCKET_ERROR) { +        packet.buffer.resize(0); + +        // This suppresses the -Wlogical-op warning +#if EAGAIN == EWOULDBLOCK +        if (errno == EAGAIN) { +#else +        if (errno == EAGAIN or errno == EWOULDBLOCK) { +#endif +            return 0; +        } +        throw runtime_error(string("Can't receive data: ") + strerror(errno)); +    } + +    packet.buffer.resize(ret); +    return packet; +} + +void UDPSocket::send(UDPPacket& packet) +{ +    const int ret = sendto(m_sock, packet.buffer.data(), packet.buffer.size(), 0, +            packet.address.as_sockaddr(), sizeof(*packet.address.as_sockaddr())); +    if (ret == SOCKET_ERROR && errno != ECONNREFUSED) { +        throw runtime_error(string("Can't send UDP packet: ") + strerror(errno)); +    } +} + + +void UDPSocket::send(const std::vector<uint8_t>& data, InetAddress destination) +{ +    const int ret = sendto(m_sock, data.data(), data.size(), 0, +            destination.as_sockaddr(), sizeof(*destination.as_sockaddr())); +    if (ret == SOCKET_ERROR && errno != ECONNREFUSED) { +        throw runtime_error(string("Can't send UDP packet: ") + strerror(errno)); +    } +} + +void UDPSocket::joinGroup(const char* groupname, const char* if_addr) +{ +    ip_mreqn group; +    if ((group.imr_multiaddr.s_addr = inet_addr(groupname)) == INADDR_NONE) { +        throw runtime_error("Cannot convert multicast group name"); +    } +    if (!IN_MULTICAST(ntohl(group.imr_multiaddr.s_addr))) { +        throw runtime_error("Group name is not a multicast address"); +    } + +    if (if_addr) { +        group.imr_address.s_addr = inet_addr(if_addr); +    } +    else { +        group.imr_address.s_addr = htons(INADDR_ANY); +    } +    group.imr_ifindex = 0; +    if (setsockopt(m_sock, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)) +            == SOCKET_ERROR) { +        throw runtime_error(string("Can't join multicast group") + strerror(errno)); +    } +} + +void UDPSocket::setMulticastSource(const char* source_addr) +{ +    struct in_addr addr; +    if (inet_aton(source_addr, &addr) == 0) { +        throw runtime_error(string("Can't parse source address") + strerror(errno)); +    } + +    if (setsockopt(m_sock, IPPROTO_IP, IP_MULTICAST_IF, &addr, sizeof(addr)) +            == SOCKET_ERROR) { +        throw runtime_error(string("Can't set source address") + strerror(errno)); +    } +} + +void UDPSocket::setMulticastTTL(int ttl) +{ +    if (setsockopt(m_sock, IPPROTO_IP, IP_MULTICAST_TTL, &ttl, sizeof(ttl)) +            == SOCKET_ERROR) { +        throw runtime_error(string("Can't set multicast ttl") + strerror(errno)); +    } +} + +UDPReceiver::~UDPReceiver() { +    m_stop = true; +    m_sock.close(); +    if (m_thread.joinable()) { +        m_thread.join(); +    } +} + +void UDPReceiver::start(int port, const string& bindto, const string& mcastaddr, size_t max_packets_queued) { +    m_port = port; +    m_bindto = bindto; +    m_mcastaddr = mcastaddr; +    m_max_packets_queued = max_packets_queued; +    m_thread = std::thread(&UDPReceiver::m_run, this); +} + +std::vector<uint8_t> UDPReceiver::get_packet_buffer() +{ +    if (m_stop) { +        throw runtime_error("UDP Receiver not running"); +    } + +    UDPPacket p; +    m_packets.wait_and_pop(p); + +    return p.buffer; +} + +void UDPReceiver::m_run() +{ +    // Ensure that stop is set to true in case of exception or return +    struct SetStopOnDestruct { +        SetStopOnDestruct(atomic<bool>& stop) : m_stop(stop) {} +        ~SetStopOnDestruct() { m_stop = true; } +        private: atomic<bool>& m_stop; +    } autoSetStop(m_stop); + +    if (IN_MULTICAST(ntohl(inet_addr(m_mcastaddr.c_str())))) { +        m_sock.reinit(m_port, m_mcastaddr); +        m_sock.setMulticastSource(m_bindto.c_str()); +        m_sock.joinGroup(m_mcastaddr.c_str(), m_bindto.c_str()); +    } +    else { +        m_sock.reinit(m_port, m_bindto); +    } + +    while (not m_stop) { +        constexpr size_t packsize = 8192; +        try { +            auto packet = m_sock.receive(packsize); +            if (packet.buffer.size() == packsize) { +                // TODO replace fprintf +                fprintf(stderr, "Warning, possible UDP truncation\n"); +            } + +            // If this blocks, the UDP socket will lose incoming packets +            m_packets.push_wait_if_full(packet, m_max_packets_queued); +        } +        catch (const std::runtime_error& e) { +            // TODO replace fprintf +            // TODO handle intr +            fprintf(stderr, "Socket error: %s\n", e.what()); +            m_stop = true; +        } +    } +} + + +TCPSocket::TCPSocket() +{ +} + +TCPSocket::~TCPSocket() +{ +    if (m_sock != -1) { +        ::close(m_sock); +    } +} + +TCPSocket::TCPSocket(TCPSocket&& other) : +    m_sock(other.m_sock), +    m_remote_address(other.m_remote_address) +{ +    if (other.m_sock != -1) { +        other.m_sock = -1; +    } +} + +TCPSocket& TCPSocket::operator=(TCPSocket&& other) +{ +    m_sock = other.m_sock; +    m_remote_address = other.m_remote_address; + +    if (other.m_sock != -1) { +        other.m_sock = -1; +    } + +    return *this; +} + +bool TCPSocket::valid() const +{ +    return m_sock != -1; +} + +void TCPSocket::connect(const std::string& hostname, int port) +{ +    if (m_sock != INVALID_SOCKET) { +        throw std::logic_error("You may only connect an invalid TCPSocket"); +    } + +    char service[NI_MAXSERV]; +    snprintf(service, NI_MAXSERV-1, "%d", port); + +    /* Obtain address(es) matching host/port */ +    struct addrinfo hints; +    memset(&hints, 0, sizeof(struct addrinfo)); +    hints.ai_family = AF_INET; +    hints.ai_socktype = SOCK_STREAM; +    hints.ai_flags = 0; +    hints.ai_protocol = 0; + +    struct addrinfo *result, *rp; +    int s = getaddrinfo(hostname.c_str(), service, &hints, &result); +    if (s != 0) { +        throw runtime_error(string("getaddrinfo failed: ") + gai_strerror(s)); +    } + +    /* getaddrinfo() returns a list of address structures. +       Try each address until we successfully connect(2). +       If socket(2) (or connect(2)) fails, we (close the socket +       and) try the next address. */ + +    for (rp = result; rp != nullptr; rp = rp->ai_next) { +        int sfd = ::socket(rp->ai_family, rp->ai_socktype, +                rp->ai_protocol); +        if (sfd == -1) +            continue; + +        int ret = ::connect(sfd, rp->ai_addr, rp->ai_addrlen); +        if (ret != -1 or (ret == -1 and errno == EINPROGRESS)) { +            // As the TCPClient could set the socket to nonblocking, we +            // must handle EINPROGRESS here +            m_sock = sfd; +            break; +        } + +        ::close(sfd); +    } + +    if (m_sock != INVALID_SOCKET) { +#if defined(HAVE_SO_NOSIGPIPE) +        int val = 1; +        if (setsockopt(m_sock, SOL_SOCKET, SO_NOSIGPIPE, &val, sizeof(val)) +                == SOCKET_ERROR) { +            throw std::runtime_error("Can't set SO_NOSIGPIPE"); +        } +#endif +    } + +    freeaddrinfo(result);           /* No longer needed */ + +    if (rp == nullptr) { +        throw runtime_error("Could not connect"); +    } + +} + +void TCPSocket::listen(int port, const string& name) +{ +    if (m_sock != INVALID_SOCKET) { +        throw std::logic_error("You may only listen with an invalid TCPSocket"); +    } + +    char service[NI_MAXSERV]; +    snprintf(service, NI_MAXSERV-1, "%d", port); + +    struct addrinfo hints; +    memset(&hints, 0, sizeof(struct addrinfo)); +    hints.ai_family = AF_INET; +    hints.ai_socktype = SOCK_STREAM; +    hints.ai_flags = AI_PASSIVE;    /* For wildcard IP address */ +    hints.ai_protocol = 0; +    hints.ai_canonname = nullptr; +    hints.ai_addr = nullptr; +    hints.ai_next = nullptr; + +    struct addrinfo *result, *rp; +    int s = getaddrinfo(name.empty() ? nullptr : name.c_str(), service, &hints, &result); +    if (s != 0) { +        throw runtime_error(string("getaddrinfo failed: ") + gai_strerror(s)); +    } + +    /* getaddrinfo() returns a list of address structures. +       Try each address until we successfully bind(2). +       If socket(2) (or bind(2)) fails, we (close the socket +       and) try the next address. */ +    for (rp = result; rp != nullptr; rp = rp->ai_next) { +        int sfd = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); +        if (sfd == -1) { +            continue; +        } + +        if (::bind(sfd, rp->ai_addr, rp->ai_addrlen) == 0) { +            m_sock = sfd; +            break; +        } + +        ::close(sfd); +    } + +    freeaddrinfo(result); + +#if defined(HAVE_SO_NOSIGPIPE) +    int val = 1; +    if (setsockopt(m_sock, SOL_SOCKET, SO_NOSIGPIPE, +                &val, sizeof(val)) < 0) { +        throw std::runtime_error("Can't set SO_NOSIGPIPE"); +    } +#endif + +    if (rp == nullptr) { +        throw runtime_error("Could not bind"); +    } +} + +void TCPSocket::close() +{ +    ::close(m_sock); +    m_sock = -1; +} + +TCPSocket TCPSocket::accept(int timeout_ms) +{ +    if (timeout_ms == 0) { +        InetAddress remote_addr; +        socklen_t client_len = sizeof(remote_addr.addr); +        int sockfd = ::accept(m_sock, remote_addr.as_sockaddr(), &client_len); +        TCPSocket s(sockfd, remote_addr); +        return s; +    } +    else { +        struct pollfd fds[1]; +        fds[0].fd = m_sock; +        fds[0].events = POLLIN; + +        int retval = poll(fds, 1, timeout_ms); + +        if (retval == -1) { +            std::string errstr(strerror(errno)); +            throw std::runtime_error("TCP Socket accept error: " + errstr); +        } +        else if (retval > 0) { +            InetAddress remote_addr; +            socklen_t client_len = sizeof(remote_addr.addr); +            int sockfd = ::accept(m_sock, remote_addr.as_sockaddr(), &client_len); +            TCPSocket s(sockfd, remote_addr); +            return s; +        } +        else { +            TCPSocket s(-1); +            return s; +        } +    } +} + +ssize_t TCPSocket::sendall(const void *buffer, size_t buflen) +{ +    uint8_t *buf = (uint8_t*)buffer; +    while (buflen > 0) { +        /* On Linux, the MSG_NOSIGNAL flag ensures that the process +         * would not receive a SIGPIPE and die. +         * Other systems have SO_NOSIGPIPE set on the socket for the +         * same effect. */ +#if defined(HAVE_MSG_NOSIGNAL) +        const int flags = MSG_NOSIGNAL; +#else +        const int flags = 0; +#endif +        ssize_t sent = ::send(m_sock, buf, buflen, flags); +        if (sent < 0) { +            return -1; +        } +        else { +            buf += sent; +            buflen -= sent; +        } +    } +    return buflen; +} + +ssize_t TCPSocket::send(const void* data, size_t size, int timeout_ms) +{ +    if (timeout_ms) { +        struct pollfd fds[1]; +        fds[0].fd = m_sock; +        fds[0].events = POLLOUT; + +        const int retval = poll(fds, 1, timeout_ms); + +        if (retval == -1) { +            throw std::runtime_error(string("TCP Socket send error on poll(): ") + strerror(errno)); +        } +        else if (retval == 0) { +            // Timed out +            return 0; +        } +    } + +    /* On Linux, the MSG_NOSIGNAL flag ensures that the process would not +     * receive a SIGPIPE and die. +     * Other systems have SO_NOSIGPIPE set on the socket for the same effect. */ +#if defined(HAVE_MSG_NOSIGNAL) +    const int flags = MSG_NOSIGNAL; +#else +    const int flags = 0; +#endif +    const ssize_t ret = ::send(m_sock, (const char*)data, size, flags); + +    if (ret == SOCKET_ERROR) { +            throw std::runtime_error(string("TCP Socket send error: ") + strerror(errno)); +    } +    return ret; +} + +ssize_t TCPSocket::recv(void *buffer, size_t length, int flags) +{ +    ssize_t ret = ::recv(m_sock, buffer, length, flags); +    if (ret == -1) { +        std::string errstr(strerror(errno)); +        throw std::runtime_error("TCP receive error: " + errstr); +    } +    return ret; +} + +ssize_t TCPSocket::recv(void *buffer, size_t length, int flags, int timeout_ms) +{ +    struct pollfd fds[1]; +    fds[0].fd = m_sock; +    fds[0].events = POLLIN; + +    int retval = poll(fds, 1, timeout_ms); + +    if (retval == -1 and errno == EINTR) { +        throw Interrupted(); +    } +    else if (retval == -1) { +        std::string errstr(strerror(errno)); +        throw std::runtime_error("TCP receive with poll() error: " + errstr); +    } +    else if (retval > 0 and (fds[0].revents | POLLIN)) { +        ssize_t ret = ::recv(m_sock, buffer, length, flags); +        if (ret == -1) { +            if (errno == ECONNREFUSED) { +                return 0; +            } +            std::string errstr(strerror(errno)); +            throw std::runtime_error("TCP receive after poll() error: " + errstr); +        } +        return ret; +    } +    else { +        throw Timeout(); +    } +} + +TCPSocket::TCPSocket(int sockfd) : +    m_sock(sockfd), +    m_remote_address() +{ } + +TCPSocket::TCPSocket(int sockfd, InetAddress remote_address) : +    m_sock(sockfd), +    m_remote_address(remote_address) +{ } + +void TCPClient::connect(const std::string& hostname, int port) +{ +    m_hostname = hostname; +    m_port = port; +    reconnect(); +} + +ssize_t TCPClient::recv(void *buffer, size_t length, int flags, int timeout_ms) +{ +    try { +        ssize_t ret = m_sock.recv(buffer, length, flags, timeout_ms); + +        if (ret == 0) { +            m_sock.close(); + +            TCPSocket newsock; +            m_sock = std::move(newsock); +            reconnect(); +        } + +        return ret; +    } +    catch (const TCPSocket::Interrupted&) { +        return -1; +    } +    catch (const TCPSocket::Timeout&) { +        return 0; +    } + +    return 0; +} + +void TCPClient::reconnect() +{ +    int flags = fcntl(m_sock.m_sock, F_GETFL); +    if (fcntl(m_sock.m_sock, F_SETFL, flags | O_NONBLOCK) == -1) { +        std::string errstr(strerror(errno)); +        throw std::runtime_error("TCP: Could not set O_NONBLOCK: " + errstr); +    } + +    m_sock.connect(m_hostname, m_port); +} + +TCPConnection::TCPConnection(TCPSocket&& sock) : +            queue(), +            m_running(true), +            m_sender_thread(), +            m_sock(move(sock)) +{ +#if MISSING_OWN_ADDR +    auto own_addr = m_sock.getOwnAddress(); +    auto addr = m_sock.getRemoteAddress(); +    etiLog.level(debug) << "New TCP Connection on port " << +        own_addr.getPort() << " from " << +        addr.getHostAddress() << ":" << addr.getPort(); +#endif +    m_sender_thread = std::thread(&TCPConnection::process, this); +} + +TCPConnection::~TCPConnection() +{ +    m_running = false; +    vector<uint8_t> termination_marker; +    queue.push(termination_marker); +    m_sender_thread.join(); +} + +void TCPConnection::process() +{ +    while (m_running) { +        vector<uint8_t> data; +        queue.wait_and_pop(data); + +        if (data.empty()) { +            // empty vector is the termination marker +            m_running = false; +            break; +        } + +        try { +            ssize_t remaining = data.size(); +            const uint8_t *buf = reinterpret_cast<const uint8_t*>(data.data()); +            const int timeout_ms = 10; // Less than one ETI frame + +            while (m_running and remaining > 0) { +                const ssize_t sent = m_sock.send(buf, remaining, timeout_ms); +                if (sent < 0 or sent > remaining) { +                    throw std::logic_error("Invalid TCPSocket::send() return value"); +                } +                remaining -= sent; +                buf += sent; +            } +        } +        catch (const std::runtime_error& e) { +            m_running = false; +        } +    } + +#if MISSING_OWN_ADDR +    auto own_addr = m_sock.getOwnAddress(); +    auto addr = m_sock.getRemoteAddress(); +    etiLog.level(debug) << "Dropping TCP Connection on port " << +        own_addr.getPort() << " from " << +        addr.getHostAddress() << ":" << addr.getPort(); +#endif +} + + +TCPDataDispatcher::TCPDataDispatcher(size_t max_queue_size) : +    m_max_queue_size(max_queue_size) +{ +} + +TCPDataDispatcher::~TCPDataDispatcher() +{ +    m_running = false; +    m_connections.clear(); +    m_listener_socket.close(); +    if (m_listener_thread.joinable()) { +        m_listener_thread.join(); +    } +} + +void TCPDataDispatcher::start(int port, const string& address) +{ +    m_listener_socket.listen(port, address); + +    m_running = true; +    m_listener_thread = std::thread(&TCPDataDispatcher::process, this); +} + +void TCPDataDispatcher::write(const vector<uint8_t>& data) +{ +    if (not m_running) { +        throw runtime_error(m_exception_data); +    } + +    for (auto& connection : m_connections) { +        connection.queue.push(data); +    } + +    m_connections.remove_if( +            [&](const TCPConnection& conn){ return conn.queue.size() > m_max_queue_size; }); +} + +void TCPDataDispatcher::process() +{ +    try { +        const int timeout_ms = 1000; + +        while (m_running) { +            // Add a new TCPConnection to the list, constructing it from the client socket +            auto sock = m_listener_socket.accept(timeout_ms); +            if (sock.valid()) { +                m_connections.emplace(m_connections.begin(), move(sock)); +            } +        } +    } +    catch (const std::runtime_error& e) { +        m_exception_data = string("TCPDataDispatcher error: ") + e.what(); +        m_running = false; +    } +} + +} | 
