//
// Copyright 2016 Ettus Research LLC
// Copyright 2018 Ettus Research, a National Instruments Company
//
// SPDX-License-Identifier: GPL-3.0-or-later
//

#include <uhd/exception.hpp>
#include <uhd/transport/bounded_buffer.hpp>
#include <uhd/transport/muxed_zero_copy_if.hpp>
#include <uhd/utils/safe_call.hpp>
#include <boost/bind.hpp>
#include <boost/enable_shared_from_this.hpp>
#include <boost/make_shared.hpp>
#include <boost/thread.hpp>
#include <boost/thread/locks.hpp>
#include <algorithm>
#include <cstring>
#include <map>
#include <vector>

using namespace uhd;
using namespace uhd::transport;

class muxed_zero_copy_if_impl
    : public muxed_zero_copy_if,
      public boost::enable_shared_from_this<muxed_zero_copy_if_impl>
{
public:
    typedef boost::shared_ptr<muxed_zero_copy_if_impl> sptr;

    muxed_zero_copy_if_impl(zero_copy_if::sptr base_xport,
        stream_classifier_fn classify_fn,
        size_t max_streams)
        : _base_xport(base_xport)
        , _classify(classify_fn)
        , _max_num_streams(max_streams)
        , _num_dropped_frames(0)
    {
        // Create the receive thread to poll the underlying transport
        // and classify packets into queues
        _recv_thread =
            boost::thread(boost::bind(&muxed_zero_copy_if_impl::_update_queues, this));
    }

    virtual ~muxed_zero_copy_if_impl()
    {
        UHD_SAFE_CALL(
            // Interrupt buffer updater loop
            _recv_thread.interrupt();
            // Wait for loop to finish
            // No timeout on join. The recv loop is guaranteed
            // to terminate in a reasonable amount of time because
            // there are no timed blocks on the underlying transport.
            _recv_thread.join();
            // Flush base transport
            while (_base_xport->get_recv_buff(0.0001)) /*NOP*/;
            // Release child streams
            // Note that this will not delete or flush the child streams
            // until the owners of the streams have released the respective
            // shared pointers. This ensures that packets are not dropped.
            _streams.clear(););
    }

    virtual zero_copy_if::sptr make_stream(const uint32_t stream_num)
    {
        boost::lock_guard<boost::mutex> lock(_mutex);
        if (_streams.size() >= _max_num_streams) {
            throw uhd::runtime_error("muxed_zero_copy_if: stream capacity exceeded. "
                                     "cannot create more streams.");
        }
        // Only allocate a portion of the base transport's frames to each stream
        // to prevent all streams from attempting to use all the frames.
        stream_impl::sptr stream =
            boost::make_shared<stream_impl>(this->shared_from_this(),
                stream_num,
                std::max<size_t>(1, _base_xport->get_num_send_frames() / _max_num_streams),
                std::max<size_t>(1, _base_xport->get_num_recv_frames() / _max_num_streams));
        _streams[stream_num] = stream;
        return stream;
    }

    virtual size_t get_num_dropped_frames() const
    {
        return _num_dropped_frames;
    }

    void remove_stream(const uint32_t stream_num)
    {
        boost::lock_guard<boost::mutex> lock(_mutex);
        _streams.erase(stream_num);
    }

private:
    /*
     * @class stream_mrb holds a private copy of one packet's payload so that
     * the original managed receive buffer can be released back to the base
     * transport immediately.
     */
    class stream_mrb : public managed_recv_buffer
    {
    public:
        stream_mrb(size_t size) : _buff(new char[size]) {}

        ~stream_mrb()
        {
            delete[] _buff;
        }

        // Nothing to release: the data is a private copy, and the slot is
        // simply reused the next time its index comes around
        void release() {}

        // Copy len bytes into the private buffer and return a fresh
        // managed buffer wrapper that points at the copy
        UHD_INLINE sptr get_new(char* buff, size_t len)
        {
            std::memcpy(_buff, buff, len);
            return make(this, _buff, len);
        }

    private:
        char* _buff;
    };
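
    // Design note: each stream recycles a fixed pool of num_recv_frames
    // stream_mrb slots round-robin. The copy in get_new() happens before
    // push_with_wait() can block on a full queue, so a consumer that falls
    // a full num_recv_frames buffers behind (or holds popped buffers that
    // long) can have a buffer's contents overwritten in place.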

    /*
     * @class stream_impl is the per-stream view of the muxed transport.
     * Receive buffers come from a private queue that the muxer thread
     * fills; send buffers pass straight through to the base transport.
     */
    class stream_impl : public zero_copy_if
    {
    public:
        typedef boost::shared_ptr<stream_impl> sptr;
        typedef boost::weak_ptr<stream_impl> wptr;

        stream_impl(muxed_zero_copy_if_impl::sptr muxed_xport,
            const uint32_t stream_num,
            const size_t num_send_frames,
            const size_t num_recv_frames)
            : _stream_num(stream_num)
            , _muxed_xport(muxed_xport)
            , _num_send_frames(num_send_frames)
            , _send_frame_size(_muxed_xport->base_xport()->get_send_frame_size())
            , _num_recv_frames(num_recv_frames)
            , _recv_frame_size(_muxed_xport->base_xport()->get_recv_frame_size())
            , _buff_queue(num_recv_frames)
            , _buffers(num_recv_frames)
            , _buffer_index(0)
        {
            for (size_t i = 0; i < num_recv_frames; i++) {
                _buffers[i] = boost::make_shared<stream_mrb>(_recv_frame_size);
            }
        }

        ~stream_impl(void)
        {
            // First remove the stream from muxed transport
            // so no more frames are pushed in
            _muxed_xport->remove_stream(_stream_num);
            // Flush the transport
            managed_recv_buffer::sptr buff;
            while (_buff_queue.pop_with_haste(buff)) {
                // NOP
            }
        }

        size_t get_num_recv_frames(void) const
        {
            return _num_recv_frames;
        }

        size_t get_recv_frame_size(void) const
        {
            return _recv_frame_size;
        }

        managed_recv_buffer::sptr get_recv_buff(double timeout)
        {
            managed_recv_buffer::sptr buff;
            if (_buff_queue.pop_with_timed_wait(buff, timeout)) {
                return buff;
            } else {
                return managed_recv_buffer::sptr();
            }
        }

        void push_recv_buff(managed_recv_buffer::sptr buff)
        {
            // Copy the packet into the next pool slot and queue the copy;
            // blocks if this stream's queue is full
            _buff_queue.push_with_wait(
                _buffers.at(_buffer_index++)->get_new(buff->cast<char*>(), buff->size()));
            _buffer_index %= _buffers.size();
        }

        size_t get_num_send_frames(void) const
        {
            return _num_send_frames;
        }

        size_t get_send_frame_size(void) const
        {
            return _send_frame_size;
        }

        managed_send_buffer::sptr get_send_buff(double timeout)
        {
            return _muxed_xport->base_xport()->get_send_buff(timeout);
        }

    private:
        const uint32_t _stream_num;
        muxed_zero_copy_if_impl::sptr _muxed_xport;
        const size_t _num_send_frames;
        const size_t _send_frame_size;
        const size_t _num_recv_frames;
        const size_t _recv_frame_size;
        bounded_buffer<managed_recv_buffer::sptr> _buff_queue;
        std::vector<boost::shared_ptr<stream_mrb>> _buffers;
        size_t _buffer_index;
    };
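
    // Note: there is no send-side demuxing. get_send_buff() on any stream
    // draws from the same base transport frame pool, so the num_send_frames
    // reported by each stream is a budget by convention, not enforced here.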

    inline zero_copy_if::sptr& base_xport()
    {
        return _base_xport;
    }

    void _update_queues()
    {
        // Run forever:
        // - Pull packets from the base transport
        // - Classify them
        // - Push them to the appropriate receive queue
        while (true) {
            { // Uninterruptible block of code
                boost::this_thread::disable_interruption interrupt_disabler;
                if (not _process_next_buffer()) {
                    // Be a good citizen and yield if no packet is processed
                    static const size_t MIN_DUR = 1;
                    boost::this_thread::sleep_for(boost::chrono::nanoseconds(MIN_DUR));
                    // We call sleep(MIN_DUR) above instead of yield() to ensure that we
                    // relinquish the current scheduler time slot.
                    // yield() is a hint to the scheduler to end the time
                    // slice early and schedule in another thread that is ready to run.
                    // However, in most situations there will be no other thread and
                    // this thread will continue to run, which would rail a CPU core.
                    // We call sleep(MIN_DUR=1) instead, which sleeps for a minimal
                    // amount of time. Ideally we would use boost::chrono::.*seconds::min(),
                    // but that is bound to 0, which makes the sleep_for call a no-op
                    // and thus useless for actually forcing a sleep.

                    //****************************************************************
                    // NOTE: This behavior makes this transport a poor choice for
                    //      low latency communication.
                    //****************************************************************
                }
            }
            // Check if the master thread has requested a shutdown
            if (boost::this_thread::interruption_requested())
                break;
        }
    }

    // Pull one packet from the base transport, classify it, and push it to
    // the owning stream's queue. Returns true if a packet was processed.
    bool _process_next_buffer()
    {
        managed_recv_buffer::sptr buff = _base_xport->get_recv_buff(0.0);
        if (buff) {
            stream_impl::sptr stream;
            try {
                const uint32_t stream_num =
                    _classify(buff->cast<void*>(), _base_xport->get_recv_frame_size());
                {
                    // Hold the stream mutex long enough to pull a bounded buffer
                    // and lock it (increment its ref count).
                    boost::lock_guard<boost::mutex> lock(_mutex);
                    stream_map_t::iterator str_iter = _streams.find(stream_num);
                    if (str_iter != _streams.end()) {
                        stream = (*str_iter).second.lock();
                    }
                }
            } catch (std::exception&) {
                // If _classify throws, stream stays null and the frame is
                // counted as dropped below
            }
            // Once a bounded buffer is acquired, we can rely on its
            // thread safety to serialize with the consumer.
            if (stream.get()) {
                stream->push_recv_buff(buff);
            } else {
                boost::lock_guard<boost::mutex> lock(_mutex);
                _num_dropped_frames++;
            }
            // We processed a packet, and there could be more coming
            // Don't yield in the next iteration.
            return true;
        } else {
            // The base transport is idle. Return false to let the
            // thread yield.
            return false;
        }
    }

    typedef std::map<uint32_t, stream_impl::wptr> stream_map_t;

    zero_copy_if::sptr _base_xport;
    stream_classifier_fn _classify;
    stream_map_t _streams;
    const size_t _max_num_streams;
    size_t _num_dropped_frames;
    boost::thread _recv_thread;
    boost::mutex _mutex;
};

muxed_zero_copy_if::sptr muxed_zero_copy_if::make(zero_copy_if::sptr base_xport,
    muxed_zero_copy_if::stream_classifier_fn classify_fn,
    size_t max_streams)
{
    return boost::make_shared<muxed_zero_copy_if_impl>(
        base_xport, classify_fn, max_streams);
}
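
// A minimal usage sketch (hypothetical caller code; the classifier and the
// stream numbers below are illustrative assumptions, not part of this file):
//
//   static uint32_t classify(void* payload, size_t size)
//   {
//       // e.g. demux on the first 32-bit word of the packet header
//       return (size >= sizeof(uint32_t))
//           ? *reinterpret_cast<const uint32_t*>(payload)
//           : 0;
//   }
//
//   zero_copy_if::sptr base = ...; // e.g. a UDP zero-copy transport
//   muxed_zero_copy_if::sptr mux =
//       muxed_zero_copy_if::make(base, &classify, 2 /* max_streams */);
//   zero_copy_if::sptr stream0 = mux->make_stream(0);
//   zero_copy_if::sptr stream1 = mux->make_stream(1);
//
//   // Each stream's get_recv_buff() now only yields packets that classify()
//   // maps to its stream number; sends pass through to the base transport.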