/*
   Copyright (C) 2017
   Matthias P. Braendli, matthias.braendli@mpb.li
    http://opendigitalradio.org
 */
/*
   This file is part of ODR-DPD.
   ODR-DPD is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as
   published by the Free Software Foundation, either version 3 of the
   License, or (at your option) any later version.
   ODR-DPD is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.
   You should have received a copy of the GNU General Public License
   along with ODR-DPD.  If not, see <http://www.gnu.org/licenses/>.
 */
#include "OutputUHD.hpp"
#include "pointcloud.hpp"
#include "AlignSample.hpp"
#include "utils.hpp"
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
// Global run flag. main() sets it before starting the worker threads; the
// receive loop (do_receive) and the correlator loop (analyse_correlation)
// poll it and stop once it is cleared. It is cleared on SIGINT, on an SDL
// quit event, and when the main loop ends.
std::atomic<bool> running;

// SIGINT handler: request a cooperative shutdown by clearing `running`.
void sig_int_handler(int /*signum*/)
{
    running.store(false);
}
/* Switch the calling thread to the SCHED_RR realtime scheduling policy.
 *
 * prio: offset added to the minimum SCHED_RR priority of the system.
 * Returns 0 on success, otherwise the error number returned by
 * pthread_setschedparam().
 */
static int set_realtime_prio(int prio)
{
    const int policy = SCHED_RR;

    // Value-initialise so that platform-specific members of sched_param
    // (anything besides sched_priority) are zero rather than indeterminate.
    sched_param sp{};
    sp.sched_priority = sched_get_priority_min(policy) + prio;

    return pthread_setschedparam(pthread_self(), policy, &sp);
}
size_t read_samples_from_file(FILE* fd, std::vector& samples, size_t count)
{
    if (samples.size() < count) {
        MDEBUG("HAD TO RESIZE BUFFER!\n");
        samples.resize(count);
    }
    size_t num_read = fread(&samples.front(), sizeof(complexf), count, fd);
    if (num_read == 0) {
        rewind(fd);
        num_read = fread(&samples.front(), sizeof(complexf), count, fd);
    }
    return num_read;
}
// Shared sample aligner. TX samples are pushed from main(), RX samples from
// do_receive(); the correlator thread (analyse_correlation and
// find_peak_correlation) cross-correlates, consumes and fetches aligned
// samples from it.
// NOTE(review): it is used from several threads; whether AlignSample
// synchronises internally is defined in AlignSample.hpp — confirm there.
AlignSample aligner;
// Point cloud display: filled by push_to_point_cloud(), drawn and polled for
// key events in main().
// NOTE(review): the meaning of the 10000 argument is defined in
// pointcloud.hpp (presumably a point capacity) — confirm there.
PointCloud cloud(10000);
/* Receive loop, run on its own thread from main().
 *
 * Repeatedly reads a samps_per_buffer-sized buffer from output_uhd and hands
 * it, together with the timestamp of its first sample, to the aligner. A
 * non-positive return from Receive() means the RX timestamp can no longer be
 * trusted, so the aligner's RX side is reset instead.
 * Runs until `running` is cleared; returns the total number of samples
 * received.
 */
size_t do_receive(OutputUHD *output_uhd)
{
    std::vector<complexf> rx_buffer(samps_per_buffer);
    double rx_timestamp = 0;
    double last_report_time = 0;
    size_t rx_total = 0;

    MDEBUG("Starting do_receive\n");
    while (running) {
        const ssize_t n = output_uhd->Receive(rx_buffer.data(), rx_buffer.size(), &rx_timestamp);

        if (n <= 0) {
            // A receive error occurred that invalidates the RX timestamp
            MDEBUG("Reset aligner RX\n");
            aligner.reset_rx();
            continue;
        }

        aligner.push_rx_samples(rx_buffer.data(), n, rx_timestamp);
        rx_total += n;

        // Once-per-second hook for RX progress logging (currently disabled).
        if (rx_timestamp - last_report_time > 1) {
            //MDEBUG("Rx %zu samples at t=%f\n", n, rx_timestamp);
            last_report_time = rx_timestamp;
        }
    }
    MDEBUG("Leaving do_receive\n");
    return rx_total;
}
// Number of samples per cross-correlation window (8ms at 2048000 samples/s).
const size_t correlation_length = 16 * 1024;

// Additional RX delay in samples, changed from the keyboard in main()
// (L/E keys) and read by the correlator thread in push_to_point_cloud().
// Atomic because it is written and read from different threads.
std::atomic<long> user_delay{0};
/* Apply the measured delay plus the user's manual offset to the aligner's RX
 * samples, then fetch correlation_length aligned samples and, if any were
 * returned, add them to the point cloud display.
 *
 * NOTE(review): rx_delay is unsigned while user_delay may be negative; their
 * sum is computed in unsigned arithmetic and wraps if negative. Confirm how
 * AlignSample::delay_rx_samples() interprets such a value.
 */
void push_to_point_cloud(size_t rx_delay)
{
    aligner.delay_rx_samples(rx_delay + user_delay);

    auto aligned = aligner.get_samples(correlation_length);
    if (aligned.first.size() == 0) {
        return; // nothing aligned yet
    }
    cloud.push_samples(aligned);
}
/* Cross-correlate the aligner's next correlation_length samples, locate the
 * correlation peak and log it to stderr together with the RX/TX levels and
 * timestamps. Then discard a large block of buffered samples.
 *
 * Returns the lag (in samples) of the largest |correlation|. Among equal
 * maxima the last lag wins; 0 is returned for an empty correlation result.
 */
size_t find_peak_correlation(size_t correlation_length)
{
    auto result = aligner.crosscorrelate(correlation_length);
    const auto& xcs = result.correlation;

    // Find correlation peak ('>=' keeps the last of equal maxima).
    double max_norm = 0.0;
    size_t pos_max = 0;
    for (size_t offset = 0; offset < xcs.size(); offset++) {
        const double mag2 = std::norm(xcs[offset]);
        if (mag2 >= max_norm) {
            max_norm = mag2;
            pos_max = offset;
        }
    }

    // rx_power/tx_power are (per their names) power values, so decibels are
    // 10*log10(p); std::log is the natural logarithm and would be wrong here.
    char msg[512];
    snprintf(msg, sizeof(msg), "Max correlation is %f at %fms (%zu), with RX %fdB and TX %fdB, RXtime %f, TXtime %f\n",
            std::sqrt(max_norm),
            static_cast<double>(pos_max) / static_cast<double>(samplerate) * 1000.0,
            pos_max,
            10*std::log10(result.rx_power),
            10*std::log10(result.tx_power),
            result.rx_timestamp,
            result.tx_timestamp);
    std::cerr << msg;

    // NOTE(review): a 1us sleep, presumably to yield to the other threads.
    std::this_thread::sleep_for(std::chrono::microseconds(1));

    // Eat much more than we correlate, because correlation is slow
    const size_t samples_to_consume = 204800;
    aligner.consume(samples_to_consume);
    return pos_max;
}
void analyse_correlation()
{
    const size_t num_analyse = 10;
    std::vector max_positions(num_analyse);
    while (running) {
        for (size_t i = 0; running and i < num_analyse; i++) {
            if (aligner.ready(correlation_length)) {
                max_positions[i] = find_peak_correlation(correlation_length);
            }
            else {
                MDEBUG("Waiting for correlation\n");
                aligner.debug();
                std::this_thread::sleep_for(std::chrono::seconds(1));
            }
        }
        bool all_identical = true;
        double mean = std::accumulate(max_positions.begin(), max_positions.end(), 0.0) / (double)max_positions.size();
        for (size_t i = 0; i < num_analyse; i++) {
            if (std::fabs(max_positions[i] - mean) > 1) {
                all_identical = false;
                break;
            }
        }
        if (all_identical) {
            size_t delay_samples = max_positions[0];
            push_to_point_cloud(delay_samples);
        }
        else {
            MDEBUG("Not all delays identical\n");
        }
    }
}
int main(int argc, char **argv)
{
    double txgain = 0;
    double rxgain = 0;
    if (argc >= 3) {
        txgain = strtod(argv[2], nullptr);
        if (!(0 <= txgain and txgain < 80)) {
            MDEBUG("txgain wrong: %f\n", txgain);
            return -1;
        }
    }
    if (argc >= 4) {
        rxgain = strtod(argv[3], nullptr);
        if (!(0 <= rxgain and rxgain < 80)) {
            MDEBUG("rxgain wrong: %f\n", rxgain);
            return -1;
        }
    }
    if (argc < 2) {
        MDEBUG("Require input file or url\n");
        return -1;
    }
    set_realtime_prio(1);
    std::string uri = argv[1];
    zmq::context_t ctx;
    zmq::socket_t zmq_sock(ctx, ZMQ_SUB);
    FILE* fd = nullptr;
    if (uri == "test") { //{{{
        FILE* fd_rx = fopen("rx.test", "r");
        if (!fd_rx) {
            std::cerr << "fx_rx open error" << std::endl;
            abort();
        }
        FILE* fd_tx = fopen("tx.test", "r");
        if (!fd_tx) {
            std::cerr << "fx_tx open error" << std::endl;
            abort();
        }
        size_t num_rx_samples;
        size_t num_tx_samples;
        do {
            const size_t len = 64;
            std::vector rx_samples(len);
            std::vector tx_samples(len);
            num_rx_samples = fread(&rx_samples.front(), sizeof(complexf), len, fd_rx);
            num_tx_samples = fread(&tx_samples.front(), sizeof(complexf), len, fd_tx);
            aligner.push_rx_samples(&rx_samples.front(), num_rx_samples, 1);
            aligner.push_tx_samples(&tx_samples.front(), num_tx_samples, 1);
            std::cerr << ".";
        } while (num_rx_samples and num_tx_samples);
        std::cerr << std::endl;
        aligner.debug();
        const size_t correlation_length = 16 * 1024;
        double max_norm = 0.0;
        size_t pos_max = 0;
        while (aligner.ready(correlation_length)) {
            auto result = aligner.crosscorrelate(correlation_length);
            auto& xcs = result.correlation;
            for (size_t offset = 0; offset < xcs.size(); offset++) {
                complexf& xc = xcs[offset];
                if (std::norm(xc) >= max_norm) {
                    max_norm = std::norm(xc);
                    pos_max = offset;
                }
            }
            MDEBUG("Max correlation is %f at %fms (%zu), with RX %fdB and TX %fdB, RXtime %f, TXtime %f\n",
                    std::sqrt(max_norm),
                    (double)pos_max / (double)samplerate * 1000.0,
                    pos_max,
                    10*std::log(result.rx_power),
                    10*std::log(result.tx_power),
                    result.rx_timestamp,
                    result.tx_timestamp);
            aligner.consume(correlation_length / 2);
        }
        return 0;
    } // }}}
    else if (uri.find("tcp://") != 0) {
        fd = fopen(uri.c_str(), "rb");
        if (!fd) {
            MDEBUG("Could not open file\n");
            return -1;
        }
    }
    else {
        zmq_sock.connect(uri.c_str());
        zmq_sock.setsockopt(ZMQ_SUBSCRIBE, NULL, 0);
    }
    OutputUHD output_uhd(txgain, rxgain, samplerate);
    size_t samps_read = 0;
    size_t total_samps_read = samps_read;
    double last_print_time = 0;
    size_t sent = 0;
    std::signal(SIGINT, &sig_int_handler);
    running = true;
    std::thread receive_thread(do_receive, &output_uhd);
    std::thread correlator_thread(analyse_correlation);
    do {
        const double first_sample_time = 4.0;
        const double sample_time = first_sample_time + (double)total_samps_read / (double)samplerate;
        if (fd) {
            std::vector input_samples(samps_per_buffer);
            samps_read = read_samples_from_file(fd, input_samples, samps_per_buffer);
            sent = output_uhd.Transmit(&input_samples.front(), samps_read, sample_time);
            aligner.push_tx_samples(&input_samples.front(), samps_read, sample_time);
        }
        else {
            zmq::message_t msg;
            if (not zmq_sock.recv(&msg)) {
                MDEBUG("zmq recv error\n");
                return -1;
            }
            if (msg.size() % sizeof(complexf) != 0) {
                MDEBUG("Received incomplete size %zu\n", msg.size());
                return -1;
            }
            samps_read = msg.size() / sizeof(complexf);
            sent = output_uhd.Transmit((complexf*)msg.data(), samps_read, sample_time);
            aligner.push_tx_samples((complexf*)msg.data(), samps_read, sample_time);
        }
        if (sample_time - last_print_time > 1) {
            //MDEBUG("Tx %zu samples at t=%f\n", samps_read, sample_time);
            last_print_time = sample_time;
        }
        total_samps_read += samps_read;
        try {
            std::string keyname = cloud.handle_event();
            if (keyname == "l") {
                user_delay += 1;
                std::cerr << "User delay: " << user_delay << std::endl;
            }
            else if (keyname == "e") {
                user_delay -= 1;
                std::cerr << "User delay: " << user_delay << std::endl;
            }
            else if (keyname == "z") {
                rxgain -= 1;
                output_uhd.SetRxGain(rxgain);
            }
            else if (keyname == "a") {
                rxgain += 1;
                output_uhd.SetRxGain(rxgain);
            }
            else if (keyname == "x") {
                txgain -= 1;
                output_uhd.SetTxGain(txgain);
            }
            else if (keyname == "s") {
                txgain += 1;
                output_uhd.SetTxGain(txgain);
            }
            else if (not keyname.empty()) {
                std::cerr << "Press L for later, E for earlier, Z/A to decrease/increase RX gain, X/S for TX gain" << std::endl;
            }
        }
        catch (sdl_quit &e) {
            running = false;
        }
        cloud.draw();
    }
    while (samps_read and sent and running);
    MDEBUG("Leaving main loop with running=%d\n", running ? 1 : 0);
    running = false;
    receive_thread.join();
    correlator_thread.join();
}