//
// Copyright 2015-2016 Ettus Research LLC
// Copyright 2018 Ettus Research, a National Instruments Company
//
// SPDX-License-Identifier: GPL-3.0-or-later
//

#include "fw_common.h"
#include "usrp2_iface.hpp"
#include "usrp2_impl.hpp"
#include <uhd/config.hpp>
#include <uhd/exception.hpp>
#include <uhd/image_loader.hpp>
#include <uhd/transport/if_addrs.hpp>
#include <uhd/transport/udp_simple.hpp>
#include <uhd/types/dict.hpp>
#include <uhd/utils/byteswap.hpp>
#include <uhd/utils/paths.hpp>
#include <uhd/utils/static.hpp>
#include <boost/algorithm/string/erase.hpp>
#include <boost/asio/ip/address_v4.hpp>
#include <boost/assign.hpp>
#include <boost/filesystem.hpp>
#include <boost/format.hpp>
#include <boost/thread.hpp>
#include <cstring>
#include <fstream>
#include <iostream>

// Shorthand for Boost.Asio's IPv4 address type.
typedef boost::asio::ip::address_v4 ip_v4;

// Shorthand for Boost.Filesystem (used below for existence and size checks).
namespace fs = boost::filesystem;
// Makes the Boost string algorithms (e.g. erase_tail_copy) usable unqualified.
using namespace boost::algorithm;

// UHD namespaces whose names are used unqualified throughout this file.
using namespace uhd;
using namespace uhd::usrp;
using namespace uhd::transport;

/*
 * Constants
 */

// Number of image bytes carried by each flash write/read packet.
#define N200_FLASH_DATA_PACKET_SIZE 256
// UDP port of the update protocol; stringized when the transport is opened.
#define N200_UDP_FW_UPDATE_PORT 49154
// Timeout handed to udp_simple::recv() for every reply
// (presumably seconds -- confirm against the udp_simple interface).
#define UDP_TIMEOUT 0.5

// Largest firmware image accepted, followed by the flash addresses at which
// the production and the safe firmware images are stored.
#define N200_FW_MAX_SIZE_BYTES 31744
#define N200_PROD_FW_IMAGE_ADDR 0x00300000
#define N200_SAFE_FW_IMAGE_ADDR 0x003F0000

// Largest FPGA image accepted, followed by the flash addresses at which
// the production and the safe FPGA images are stored.
#define N200_FPGA_MAX_SIZE_BYTES 1572864
#define N200_PROD_FPGA_IMAGE_ADDR 0x00180000
#define N200_SAFE_FPGA_IMAGE_ADDR 0x00000000

/*
 * Packet codes
 *
 * Every code is a single ASCII character that travels in the "id" field
 * of an update packet. The *_CMD / QUERY codes are lower-case letters and
 * the *_ACK codes are upper-case letters.
 */
enum n200_fw_update_id_t {
    // Placeholder for "no valid code"
    UNKNOWN = ' ',

    // Device query
    N200_QUERY = 'a',
    N200_ACK   = 'A',

    // Flash geometry query
    GET_FLASH_INFO_CMD = 'f',
    GET_FLASH_INFO_ACK = 'F',

    // Start erasing a flash region
    ERASE_FLASH_CMD = 'e',
    ERASE_FLASH_ACK = 'E',

    // Poll for completion of a running erase
    CHECK_ERASING_DONE_CMD = 'd',
    DONE_ERASING_ACK       = 'D',
    NOT_DONE_ERASING_ACK   = 'B',

    // Write one packet's worth of data to flash
    WRITE_FLASH_CMD = 'w',
    WRITE_FLASH_ACK = 'W',

    // Read one packet's worth of data back from flash
    READ_FLASH_CMD = 'r',
    READ_FLASH_ACK = 'R',

    // Reset the device
    RESET_CMD = 's',
    RESET_ACK = 'S',

    // Hardware revision query
    GET_HW_REV_CMD = 'v',
    GET_HW_REV_ACK = 'V'
};

/*
 * Mapping revision numbers to names
 *
 * The names are used to build default image filenames and to label the
 * device in user-facing output.
 */
static const uhd::dict<uint32_t, std::string> n200_filename_map =
    boost::assign::map_list_of
        (0, "n2xx") // Is an N-Series, but the EEPROM value is invalid
        (0xa, "n200_r3")
        (0x100a, "n200_r4")
        (0x10a, "n210_r3")
        (0x110a, "n210_r4");

/*
 * Packet structure
 *
 * The same layout is used for outgoing commands and for replies, which are
 * read by overlaying this struct on the receive buffer. A fixed three-word
 * header is followed by a payload whose interpretation depends on the
 * packet code.
 */
struct n200_fw_update_data_t
{
    // Fixed header
    uint32_t proto_ver; // protocol compatibility number
    uint32_t id; // packet code (see n200_fw_update_id_t)
    uint32_t seq;

    // Payload; which member is meaningful depends on the packet code
    union {
        uint32_t ip_addr;
        uint32_t hw_rev; // hardware revision reported by the device

        // Arguments for flash erase/write/read
        struct
        {
            uint32_t flash_addr;
            uint32_t length;
            uint8_t data[256];
        } flash_args;

        // Flash geometry
        struct
        {
            uint32_t sector_size_bytes;
            uint32_t memory_size_bytes;
        } flash_info_args;
    } data;
};

/*
 * N-Series burn session
 *
 * Bundles everything the setup, erase, write, verify and reset steps need
 * to share while one image is being loaded onto a device.
 */
struct n200_session_t
{
    bool fw; // true: firmware image, false: FPGA image
    bool overwrite_safe; // target the safe image instead of the production image
    bool reset; // reset the device once loading is finished
    uhd::device_addr_t dev_addr; // device being loaded
    std::string burn_type; // image kind as shown in progress messages
    std::string filepath; // path of the image file
    uint8_t data_in[udp_simple::mtu]; // receive buffer for device replies
    uint32_t size; // size of the image file in bytes
    uint32_t max_size; // largest image size allowed for this image kind
    uint32_t flash_addr; // flash address at which the image is stored
    udp_simple::sptr xport; // UDP transport to the device's update port
};

/***********************************************************************
 * uhd::image_loader functionality
 **********************************************************************/

/*
 * Build the usrp2_card_burner_gui.py command line that matches the given
 * loader arguments and throw it inside a uhd::runtime_error telling the
 * user that USRP2 devices must be updated through their SD card instead.
 *
 * This function never returns normally.
 */
static void print_usrp2_error(const image_loader::image_loader_args_t& image_loader_args)
{
    // Platform-specific command prefix and line continuation
#ifdef UHD_PLATFORM_WIN32
    const std::string prefix     = "\"";
    const std::string line_break = " ^\n    ";
#else
    const std::string prefix     = "sudo \"";
    const std::string line_break = " \\\n    ";
#endif

    std::string command = prefix;
    command += find_utility("usrp2_card_burner_gui.py");
    command += "\"";

    // Add one option per requested image, falling back to the default
    // USRP2 image when the user gave no explicit path
    if (image_loader_args.load_firmware) {
        const std::string fw_path = image_loader_args.firmware_path.empty()
                                        ? find_image_path("usrp2_fw.bin")
                                        : image_loader_args.firmware_path;
        command += str(boost::format("%s--fw=\"%s\"") % line_break % fw_path);
    }
    if (image_loader_args.load_fpga) {
        const std::string fpga_path = image_loader_args.fpga_path.empty()
                                          ? find_image_path("usrp2_fpga.bin")
                                          : image_loader_args.fpga_path;
        command += str(boost::format("%s--fpga=\"%s\"") % line_break % fpga_path);
    }

    const std::string message = str(
        boost::format(
            "The specified device is a USRP2, which is not supported by this utility.\n"
            "Instead, plug the device's SD card into your machine and run this "
            "command:\n\n"
            "%s")
        % command);
    throw uhd::runtime_error(message);
}

/*
 * Ethernet communication functions
 */
/*
 * Send one command packet and wait for a reply.
 *
 * Fills in the protocol version and packet code of *pkt_out (both in
 * network byte order); all other fields must already have been set by the
 * caller. The whole packet structure is sent.
 *
 * \param xport transport to send and receive on
 * \param pkt_code packet code to place in the outgoing packet
 * \param pkt_out outgoing packet
 * \param data receive buffer, at least udp_simple::mtu bytes long
 * \return the value returned by recv() for the reply
 */
static UHD_INLINE size_t n200_send_and_recv(udp_simple::sptr xport,
    n200_fw_update_id_t pkt_code,
    n200_fw_update_data_t* pkt_out,
    uint8_t* data)
{
    // Header fields common to every command
    pkt_out->id        = htonx<uint32_t>(pkt_code);
    pkt_out->proto_ver = htonx<uint32_t>(USRP2_FW_COMPAT_NUM);

    xport->send(boost::asio::buffer(pkt_out, sizeof(n200_fw_update_data_t)));

    const size_t reply_len =
        xport->recv(boost::asio::buffer(data, udp_simple::mtu), UDP_TIMEOUT);
    return reply_len;
}

/*
 * Check whether a received packet is the expected reply.
 *
 * \param pkt_in received packet
 * \param pkt_code packet code the reply is expected to carry
 * \param len number of bytes received
 * \return true if the reply extends past the packet header and carries
 *         pkt_code in its id field
 */
static UHD_INLINE bool n200_response_matches(
    const n200_fw_update_data_t* pkt_in, n200_fw_update_id_t pkt_code, size_t len)
{
    // Too short to carry any payload: the id field is not even inspected
    if (len <= offsetof(n200_fw_update_data_t, data)) {
        return false;
    }

    return ntohl(pkt_in->id) == (unsigned)pkt_code;
}

/*
 * Locate the single N-Series device described by the loader arguments.
 *
 * Every device returned by usrp2_find() is asked for its hardware revision
 * over the UDP update port. Devices that answer are kept and tagged with a
 * "hw_rev" entry.
 *
 * \return the matching device, or an empty device_addr_t if none was found
 * \throws uhd::runtime_error if a device sends an unexpected reply, if a
 *         device selected through addr/serial/name does not answer (it is
 *         then assumed to be a USRP2), or if more than one device answers
 */
static uhd::device_addr_t n200_find(
    const image_loader::image_loader_args_t& image_loader_args)
{
    const bool user_specified = image_loader_args.args.has_key("addr")
                                or image_loader_args.args.has_key("serial")
                                or image_loader_args.args.has_key("name");

    const uhd::device_addrs_t candidates = usrp2_find(image_loader_args.args);
    if (candidates.empty()) {
        return uhd::device_addr_t();
    }

    uhd::device_addrs_t n_series;
    udp_simple::sptr rev_xport;
    n200_fw_update_data_t request;
    uint8_t reply_buf[udp_simple::mtu];
    const n200_fw_update_data_t* reply =
        reinterpret_cast<const n200_fw_update_data_t*>(reply_buf);
    const size_t header_len = offsetof(n200_fw_update_data_t, data);

    /*
     * Only N-Series devices answer a query on the UDP update port, which
     * is how plain USRP2 devices are filtered out here. A silent device
     * is only an error when the user asked for it specifically.
     */
    for (const uhd::device_addr_t& candidate : candidates) {
        rev_xport = udp_simple::make_connected(
            candidate.get("addr"), BOOST_STRINGIZE(N200_UDP_FW_UPDATE_PORT));

        const size_t reply_len =
            n200_send_and_recv(rev_xport, GET_HW_REV_CMD, &request, reply_buf);

        if (n200_response_matches(reply, GET_HW_REV_ACK, reply_len)) {
            // Unknown revision numbers fall back to the generic "n2xx"
            const uint32_t rev = ntohl(reply->data.hw_rev);
            n_series.push_back(candidate);
            n_series.back()["hw_rev"] = n200_filename_map.get(rev, "n2xx");
            continue;
        }

        if (reply_len > header_len and ntohl(reply->id) != GET_HW_REV_ACK) {
            throw uhd::runtime_error(
                str(boost::format("Received invalid reply %d from device.")
                    % ntohl(reply->id)));
        }

        if (user_specified) {
            // No response at all, so assume the device is a USRP2
            print_usrp2_error(image_loader_args);
        }
    }

    // Exactly one N-Series device is the only successful outcome
    if (n_series.size() == 1) {
        return n_series.front();
    }

    if (n_series.size() > 1) {
        std::string err_msg =
            "Could not resolve given args to a single N-Series device.\n"
            "Applicable devices:\n";

        for (const uhd::device_addr_t& match : n_series) {
            err_msg += str(boost::format("* %s (addr=%s)\n") % match.get("hw_rev")
                           % match.get("addr"));
        }

        err_msg += "\nSpecify one of these devices with the given args to load an "
                   "image onto it.";

        throw uhd::runtime_error(err_msg);
    }

    return uhd::device_addr_t();
}

/*
 * Validate firmware image
 *
 * Checks that the file at session.filepath exists, is no larger than
 * N200_FW_MAX_SIZE_BYTES and starts with four bytes of value 11. On
 * success, session.size and session.max_size are filled in.
 *
 * Throws uhd::runtime_error if any of the checks fails.
 */
static void n200_validate_firmware_image(n200_session_t& session)
{
    if (not fs::exists(session.filepath)) {
        throw uhd::runtime_error(str(
            boost::format("Could not find image at path \"%s\".") % session.filepath));
    }

    session.max_size = N200_FW_MAX_SIZE_BYTES;

    // Compare at the filesystem's native width: narrowing to 32 bits first
    // would let a huge file wrap around and slip past the limit.
    const auto image_size = fs::file_size(session.filepath);
    if (image_size > session.max_size) {
        throw uhd::runtime_error(
            str(boost::format("The specified firmware image is too large: %d vs. %d")
                % image_size % session.max_size));
    }
    session.size = static_cast<uint32_t>(image_size);

    // File must have proper header
    uint8_t test_bytes[4] = {0, 0, 0, 0};
    std::ifstream image_file(session.filepath.c_str(), std::ios::binary);
    image_file.read(reinterpret_cast<char*>(test_bytes), sizeof(test_bytes));
    // A short or failed read cannot be a valid header
    bool is_good = (image_file.gcount() == std::streamsize(sizeof(test_bytes)));
    image_file.close();

    for (size_t i = 0; is_good and i < sizeof(test_bytes); i++) {
        if (test_bytes[i] != 11) {
            is_good = false;
        }
    }

    if (not is_good) {
        throw uhd::runtime_error(
            str(boost::format("The file at path \"%s\" is not a valid firmware image.")
                % session.filepath));
    }
}

/*
 * Validate FPGA image
 *
 * Checks that the file at session.filepath exists, is no larger than
 * N200_FPGA_MAX_SIZE_BYTES and contains the byte pair 170, 153 (0xAA, 0x99)
 * within its first 63 bytes. On success, session.size and session.max_size
 * are filled in.
 *
 * Throws uhd::runtime_error if any of the checks fails.
 */
static void n200_validate_fpga_image(n200_session_t& session)
{
    if (not fs::exists(session.filepath)) {
        throw uhd::runtime_error(str(
            boost::format("Could not find image at path \"%s\".") % session.filepath));
    }

    session.max_size = N200_FPGA_MAX_SIZE_BYTES;

    // Compare at the filesystem's native width: narrowing to 32 bits first
    // would let a huge file wrap around and slip past the limit.
    const auto image_size = fs::file_size(session.filepath);
    if (image_size > session.max_size) {
        throw uhd::runtime_error(
            str(boost::format("The specified FPGA image is too large: %d vs. %d")
                % image_size % session.max_size));
    }
    session.size = static_cast<uint32_t>(image_size);

    // File must have proper header
    uint8_t test_bytes[63] = {};
    std::ifstream image_file(session.filepath.c_str(), std::ios::binary);
    image_file.read(reinterpret_cast<char*>(test_bytes), sizeof(test_bytes));
    // Only bytes that were really read take part in the search, so a file
    // shorter than the header window is never matched against stale data
    const std::streamsize num_read = image_file.gcount();
    image_file.close();

    bool is_good = false;
    for (std::streamsize i = 0; i + 1 < num_read; i++) {
        if (test_bytes[i] == 170 and test_bytes[i + 1] == 153) {
            is_good = true;
            break;
        }
    }

    if (not is_good) {
        throw uhd::runtime_error(
            str(boost::format("The file at path \"%s\" is not a valid FPGA image.")
                % session.filepath));
    }
}

/*
 * Set up a session for burning an N-Series image. This session info
 * will be passed into the erase, burn, and verify functions.
 *
 * session.dev_addr must already be set (including its "hw_rev" entry).
 *
 * \param session session to fill in
 * \param image_loader_args user-supplied loader arguments
 * \param fw true to set up for the firmware image, false for the FPGA image
 * \throws uhd::runtime_error if no path was given and the device revision
 *         is unknown, or if the image file fails validation
 */
static void n200_setup_session(n200_session_t& session,
    const image_loader::image_loader_args_t& image_loader_args,
    bool fw)
{
    session.fw    = fw;
    session.reset = image_loader_args.args.has_key("reset");

    // Only the path for the image kind being set up matters here. Looking
    // at the FPGA path while setting up firmware would discard a firmware
    // path the user supplied whenever no FPGA path was given.
    const std::string& user_path = session.fw ? image_loader_args.firmware_path
                                              : image_loader_args.fpga_path;

    /*
     * If no filepath is given, attempt to determine the default image by
     * querying the device for its revision. If the device has a corrupt
     * EEPROM or is otherwise unable to provide its revision, this is
     * impossible, and the user must manually provide an image file.
     */
    if (user_path.empty()) {
        if (session.dev_addr["hw_rev"] == "n2xx") {
            throw uhd::runtime_error("This device's revision cannot be determined. "
                                     "You must manually specify a filepath.");
        }

        if (session.fw) {
            // The firmware filename drops the last three characters of the
            // revision name (e.g. the "_r4" of "n210_r4")
            session.filepath = find_image_path(
                str(boost::format("usrp_%s_fw.bin")
                    % erase_tail_copy(session.dev_addr["hw_rev"], 3)));
        } else {
            session.filepath = find_image_path(
                str(boost::format("usrp_%s_fpga.bin") % session.dev_addr["hw_rev"]));
        }
    } else {
        session.filepath = user_path;
    }

    // Validation also fills in session.size and session.max_size
    if (session.fw)
        n200_validate_firmware_image(session);
    else
        n200_validate_fpga_image(session);

    // Choose the flash region: safe image only on explicit request
    session.overwrite_safe = image_loader_args.args.has_key("overwrite-safe");
    if (session.overwrite_safe) {
        session.flash_addr = session.fw ? N200_SAFE_FW_IMAGE_ADDR
                                        : N200_SAFE_FPGA_IMAGE_ADDR;
        session.burn_type = session.fw ? "firmware safe" : "FPGA safe";
    } else {
        session.flash_addr = session.fw ? N200_PROD_FW_IMAGE_ADDR
                                        : N200_PROD_FPGA_IMAGE_ADDR;
        session.burn_type = session.fw ? "firmware" : "FPGA";
    }

    session.xport = udp_simple::make_connected(
        session.dev_addr["addr"], BOOST_STRINGIZE(N200_UDP_FW_UPDATE_PORT));
}

/*
 * Erase the flash region that the session's image will occupy.
 *
 * Sends the erase command for session.size bytes at session.flash_addr and
 * then polls the device until it reports that erasing is done.
 *
 * Throws uhd::runtime_error on a timeout or an unexpected reply.
 */
static void n200_erase_image(n200_session_t& session)
{
    n200_fw_update_data_t request;
    // Replies are read by overlaying the packet layout on the receive buffer
    const n200_fw_update_data_t* reply =
        reinterpret_cast<const n200_fw_update_data_t*>(session.data_in);
    const size_t header_len = offsetof(n200_fw_update_data_t, data);

    // Region to erase
    request.data.flash_args.flash_addr = htonx<uint32_t>(session.flash_addr);
    request.data.flash_args.length     = htonx<uint32_t>(session.size);

    // Begin erasing
    size_t reply_len =
        n200_send_and_recv(session.xport, ERASE_FLASH_CMD, &request, session.data_in);
    if (not n200_response_matches(reply, ERASE_FLASH_ACK, reply_len)) {
        std::cout << "failed." << std::endl;
        if (reply_len < header_len) {
            throw uhd::runtime_error("Timed out waiting for reply from device.");
        }
        if (ntohl(reply->id) != ERASE_FLASH_ACK) {
            throw uhd::runtime_error(
                str(boost::format("Received invalid reply %d from device.\n")
                    % ntohl(reply->id)));
        }
        throw uhd::runtime_error("Did not receive response from device.");
    }
    std::cout << boost::format("-- Erasing %s image...") % session.burn_type
              << std::flush;

    // Keep polling for as long as the device says it is still erasing
    for (;;) {
        reply_len = n200_send_and_recv(
            session.xport, CHECK_ERASING_DONE_CMD, &request, session.data_in);

        if (n200_response_matches(reply, DONE_ERASING_ACK, reply_len)) {
            std::cout << "successful." << std::endl;
            return;
        }
        if (reply_len < header_len) {
            std::cout << "failed." << std::endl;
            throw uhd::runtime_error("Timed out waiting for reply from device.");
        }
        if (ntohl(reply->id) != NOT_DONE_ERASING_ACK) {
            std::cout << "failed." << std::endl;
            throw uhd::runtime_error(
                str(boost::format("Received invalid reply %d from device.\n")
                    % ntohl(reply->id)));
        }
    }
}

/*
 * Write the session's image file to flash, one packet at a time.
 *
 * The file is sent in chunks of N200_FLASH_DATA_PACKET_SIZE bytes starting
 * at session.flash_addr; a final partial chunk is padded with zeros.
 * Progress is printed to stdout.
 *
 * Throws uhd::runtime_error if the image cannot be opened, on a timeout, or
 * on an unexpected reply.
 */
static void n200_write_image(n200_session_t& session)
{
    n200_fw_update_data_t pkt_out;
    // Replies are read by overlaying the packet layout on the receive buffer
    const n200_fw_update_data_t* pkt_in =
        reinterpret_cast<const n200_fw_update_data_t*>(session.data_in);

    // Refuse to continue without the file: a stream that failed to open
    // would otherwise yield nothing but zero-filled packets for the flash.
    std::ifstream image(session.filepath.c_str(), std::ios::binary);
    if (not image.is_open()) {
        throw uhd::runtime_error(str(
            boost::format("Could not open image at path \"%s\".") % session.filepath));
    }

    uint32_t current_addr          = session.flash_addr;
    pkt_out.data.flash_args.length = htonx<uint32_t>(N200_FLASH_DATA_PACKET_SIZE);

    // Share of the image written so far, in percent
    const auto percent_done = [&]() {
        return int(
            (double(current_addr - session.flash_addr) / double(session.size)) * 100);
    };
    const auto report_failure = [&]() {
        std::cout << boost::format("\r--Writing %s image..failed at %d%%.")
                         % session.burn_type % percent_done()
                  << std::endl;
    };

    // Round up so that a trailing partial chunk is sent, but do not add a
    // packet when the size is an exact multiple of the chunk size: that
    // extra packet would write zeros just past the end of the image.
    const size_t num_packets =
        (session.size / N200_FLASH_DATA_PACKET_SIZE)
        + ((session.size % N200_FLASH_DATA_PACKET_SIZE) != 0 ? 1 : 0);

    for (size_t i = 0; i < num_packets; i++) {
        pkt_out.data.flash_args.flash_addr = htonx<uint32_t>(current_addr);
        // Zero first so that a short final read leaves zero padding
        memset(pkt_out.data.flash_args.data, 0x0, N200_FLASH_DATA_PACKET_SIZE);
        image.read((char*)pkt_out.data.flash_args.data, N200_FLASH_DATA_PACKET_SIZE);

        const size_t len =
            n200_send_and_recv(session.xport, WRITE_FLASH_CMD, &pkt_out, session.data_in);
        if (n200_response_matches(pkt_in, WRITE_FLASH_ACK, len)) {
            std::cout << boost::format("\r-- Writing %s image (%d%%)") % session.burn_type
                             % percent_done()
                      << std::flush;
        } else if (len < offsetof(n200_fw_update_data_t, data)) {
            image.close();
            report_failure();
            throw uhd::runtime_error("Timed out waiting for reply from device.");
        } else if (ntohl(pkt_in->id) != WRITE_FLASH_ACK) {
            image.close();
            report_failure();
            throw uhd::runtime_error(
                str(boost::format("Received invalid reply %d from device.\n")
                    % ntohl(pkt_in->id)));
        }

        current_addr += N200_FLASH_DATA_PACKET_SIZE;
    }
    std::cout << boost::format("\r-- Writing %s image...successful.") % session.burn_type
              << std::endl;

    image.close();
}

/*
 * Read the image back from flash and compare it against the image file.
 *
 * Flash is read in chunks of N200_FLASH_DATA_PACKET_SIZE bytes starting at
 * session.flash_addr; each chunk is compared against as many bytes as the
 * file provided for it. Progress is printed to stdout.
 *
 * Throws uhd::runtime_error if the image cannot be opened, on a mismatch,
 * on a timeout, or on an unexpected reply.
 */
static void n200_verify_image(n200_session_t& session)
{
    n200_fw_update_data_t pkt_out;
    // Replies are read by overlaying the packet layout on the receive
    // buffer; non-const because the payload area is cleared before each read
    n200_fw_update_data_t* pkt_in =
        reinterpret_cast<n200_fw_update_data_t*>(session.data_in);

    // Without the file every comparison below would cover zero bytes and
    // verification would succeed without having checked anything.
    std::ifstream image(session.filepath.c_str(), std::ios::binary);
    if (not image.is_open()) {
        throw uhd::runtime_error(str(
            boost::format("Could not open image at path \"%s\".") % session.filepath));
    }

    uint8_t image_part[N200_FLASH_DATA_PACKET_SIZE];
    uint32_t current_addr          = session.flash_addr;
    pkt_out.data.flash_args.length = htonx<uint32_t>(N200_FLASH_DATA_PACKET_SIZE);

    // Share of the image verified so far, in percent
    const auto percent_done = [&]() {
        return int(
            (double(current_addr - session.flash_addr) / double(session.size)) * 100);
    };
    const auto report_failure = [&]() {
        std::cout << boost::format("\r-- Verifying %s image...failed at %d%%.")
                         % session.burn_type % percent_done()
                  << std::endl;
    };

    // Round up to cover a trailing partial chunk, without the redundant
    // extra read when the size is an exact multiple of the chunk size.
    const size_t num_packets =
        (session.size / N200_FLASH_DATA_PACKET_SIZE)
        + ((session.size % N200_FLASH_DATA_PACKET_SIZE) != 0 ? 1 : 0);

    for (size_t i = 0; i < num_packets; i++) {
        // Clear both sides so stale data from the previous chunk cannot match
        memset(image_part, 0x0, N200_FLASH_DATA_PACKET_SIZE);
        memset(pkt_in->data.flash_args.data, 0x0, N200_FLASH_DATA_PACKET_SIZE);

        pkt_out.data.flash_args.flash_addr = htonx<uint32_t>(current_addr);
        image.read((char*)image_part, N200_FLASH_DATA_PACKET_SIZE);
        // Compare only what the file provided for this chunk
        const size_t cmp_len = size_t(image.gcount());

        const size_t len =
            n200_send_and_recv(session.xport, READ_FLASH_CMD, &pkt_out, session.data_in);
        if (n200_response_matches(pkt_in, READ_FLASH_ACK, len)) {
            std::cout << boost::format("\r-- Verifying %s image (%d%%)")
                             % session.burn_type % percent_done()
                      << std::flush;

            if (memcmp(image_part, pkt_in->data.flash_args.data, cmp_len)) {
                report_failure();
                throw uhd::runtime_error(
                    str(boost::format("Failed to verify %s image.") % session.burn_type));
            }
        } else if (len < offsetof(n200_fw_update_data_t, data)) {
            image.close();
            report_failure();
            throw uhd::runtime_error("Timed out waiting for reply from device.");
        } else if (ntohl(pkt_in->id) != READ_FLASH_ACK) {
            image.close();
            report_failure();
            throw uhd::runtime_error(
                str(boost::format("Received invalid reply %d from device.\n")
                    % ntohl(pkt_in->id)));
        }

        current_addr += N200_FLASH_DATA_PACKET_SIZE;
    }
    std::cout << boost::format("\r-- Verifying %s image...successful.")
                     % session.burn_type
              << std::endl;

    image.close();
}

/*
 * Send the reset command to the device.
 *
 * Any reply at all is treated as a failed reset.
 *
 * Throws uhd::runtime_error if the device replies.
 */
static void n200_reset(n200_session_t& session)
{
    n200_fw_update_data_t request;

    std::cout << "-- Resetting device..." << std::flush;

    // There should be no response
    const size_t reply_len =
        n200_send_and_recv(session.xport, RESET_CMD, &request, session.data_in);
    if (reply_len == 0) {
        std::cout << "successful." << std::endl;
        return;
    }

    std::cout << "failed." << std::endl;
    throw uhd::runtime_error("Failed to reset N200.");
}

/*
 * Turn a revision name into a display name: the first character is
 * upper-cased and the first underscore becomes a space.
 *
 * n210_r4 -> N210 r4
 *
 * An empty input yields an empty string.
 */
static std::string nice_name(const std::string& fw_rev)
{
    std::string ret = fw_rev;
    if (ret.empty()) {
        return ret;
    }

    // toupper() is only defined for values representable as unsigned char
    // (or EOF), so go through unsigned char rather than passing a plain char
    ret[0] = static_cast<char>(::toupper(static_cast<unsigned char>(ret[0])));

    const size_t pos = ret.find('_');
    if (pos != std::string::npos) {
        ret[pos] = ' ';
    }

    return ret;
}

/*
 * Image loader entry point for N-Series devices.
 *
 * Loads the firmware image and/or the FPGA image (firmware first), running
 * erase, write and verify for each. If a reset was requested, it is
 * performed once, after the last image has been loaded.
 *
 * \param image_loader_args user-supplied loader arguments
 * \return false if neither image was requested or no N-Series device was
 *         found; true once all requested images have been loaded
 */
static bool n200_image_loader(const image_loader::image_loader_args_t& image_loader_args)
{
    const bool want_fw   = image_loader_args.load_firmware;
    const bool want_fpga = image_loader_args.load_fpga;
    if (not(want_fw or want_fpga)) {
        return false;
    }

    // See if any N2x0 with the given args is found
    // This will throw if specific args lead to a USRP2
    n200_session_t session;
    session.dev_addr = n200_find(image_loader_args);
    if (session.dev_addr.size() == 0) {
        return false;
    }

    std::cout << boost::format("Unit: USRP %s (%s, %s)")
                     % nice_name(session.dev_addr.get("hw_rev"))
                     % session.dev_addr.get("serial") % session.dev_addr.get("addr")
              << std::endl;

    // The burn sequence is the same for both image kinds
    const auto burn = [&session]() {
        n200_erase_image(session);
        n200_write_image(session);
        n200_verify_image(session);
    };

    if (want_fw) {
        n200_setup_session(session, image_loader_args, true);
        std::cout << "Firmware image: " << session.filepath << std::endl;
        burn();

        // When an FPGA image follows, the reset is left to that step
        if (session.reset and not want_fpga) {
            n200_reset(session);
        }
    }

    if (want_fpga) {
        n200_setup_session(session, image_loader_args, false);
        std::cout << "FPGA image: " << session.filepath << std::endl;
        burn();

        if (session.reset) {
            n200_reset(session);
        }
    }

    return true;
}

/*
 * Register the N-Series loader with uhd::image_loader under the "usrp2"
 * device type, together with the recovery text for that device type
 * (presumably shown when a load is aborted -- confirm against
 * uhd::image_loader).
 */
UHD_STATIC_BLOCK(register_n200_image_loader)
{
    const std::string recovery_instructions(
        "Aborting. Your USRP-N Series unit will likely be unusable.\n"
        "Refer to http://files.ettus.com/manual/page_usrp2.html#usrp2_loadflash_brick\n"
        "for details on restoring your device.");

    image_loader::register_image_loader(
        "usrp2", n200_image_loader, recovery_instructions);
}