#!@RUNTIME_PYTHON_EXECUTABLE@
#
# Copyright 2018 Ettus Research, a National Instruments Company
#
# SPDX-License-Identifier: GPL-3.0-or-later
#
"""
Download image files required for USRPs
"""
from __future__ import print_function
import argparse
import hashlib
import json
import math
import os
import re
import shutil
import sys
import tempfile
import zipfile
import platform
# For all the non-core-library imports, we will be extra paranoid and be very
# nice with error messages so that everyone understands what's up.
try:
    from urllib.parse import urljoin  # Python 3
except ImportError:
    from urlparse import urljoin      # Python 2
try:
    from builtins import input
except ImportError:
    input = raw_input
try:
    from six import iteritems
except ImportError:
    sys.stdout.write(
        "[ERROR] Missing module 'six'! Please install it, e.g., by "
        "running 'pip install six' or any other tool that can install "
        "Python modules.\n")
    if platform.system() == 'Windows':
        input('Hit Enter to continue.')
    exit(0)
try:
    import requests
except ImportError:
    sys.stdout.write(
        "[ERROR] Missing module 'requests'! Please install it, e.g., by "
        "running 'pip install requests' or any other tool that can install "
        "Python modules.\n")
    if platform.system() == 'Windows':
        input('Hit Enter to continue.')
    exit(0)



_DEFAULT_TARGET_REGEX     = "(fpga|fw|windrv)_default"
_BASE_DIR_STRUCTURE_PARTS = ["share", "uhd", "images"]
_DEFAULT_INSTALL_PATH     = os.path.join("@CMAKE_INSTALL_PREFIX@", *_BASE_DIR_STRUCTURE_PARTS)
_DEFAULT_BASE_URL         = "http://files.ettus.com/binaries/cache/"
_INVENTORY_FILENAME       = "inventory.json"
_CONTACT                  = "support@ettus.com"
_DEFAULT_BUFFER_SIZE      = 8192
_DEFAULT_DOWNLOAD_LIMIT   = 100 * 1024 * 1024 # Bytes
_ARCHIVE_DEFAULT_TYPE     = "zip"
_UHD_VERSION              = "@UHD_VERSION@"
# Note: _MANIFEST_CONTENTS are placed at the bottom of this file for aesthetic reasons
_LOG_LEVELS = {"TRACE": 1,
               "DEBUG": 2,
               "INFO": 3,
               "WARN": 4,
               "ERROR": 5}
_LOG_LEVEL = _LOG_LEVELS["INFO"]
_YES = False
_PROXIES = {}


def log(level, message):
    """Logging function"""
    message_log_level = _LOG_LEVELS.get(level, 0)
    if message_log_level >= _LOG_LEVEL:
        sys.stderr.write(
            "[{level}] {message}\n".format(level=level, message=message))


def parse_args():
    """
    Setup argument parser and parse. Also does some sanity checks and sets some
    global variables we want to use.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-t', '--types', action='append',
                        help="RegEx to select image sets from the manifest file.")
    parser.add_argument('-i', '--install-location',
                        default=None,
                        help="Set custom install location for images")
    parser.add_argument('-m', '--manifest-location', type=str, default="",
                        help="Set custom location for the manifest file")
    parser.add_argument('-I', '--inventory-location', type=str, default="",
                        help="Set custom location for the inventory file")
    parser.add_argument('-l', '--list-targets', action="store_true", default=False,
                        help="Print targets in the manifest file to stdout, and exit.\n"
                        "To get relative paths only, specify an empty base URL (-b '').")
    parser.add_argument('--url-only', action="store_true", default=False,
                        help="With -l, only print the URLs, nothing else.")
    parser.add_argument("--buffer-size", type=int, default=_DEFAULT_BUFFER_SIZE,
                        help="Set download buffer size")
    parser.add_argument("--download-limit", type=int, default=_DEFAULT_DOWNLOAD_LIMIT,
                        help="Set threshold for download limits. Any download "
                             "larger than this will require approval, either "
                             "interactively, or by providing --yes.")
    parser.add_argument("--http-proxy", type=str,
                        help="Specify HTTP proxy in the format "
                        "http://user:pass@1.2.3.4:port\n"
                        "If this this option is not given, the environment "
                        "variable HTTP_PROXY can also be used to specify a proxy.")
    parser.add_argument("-b", "--base-url", type=str, default=_DEFAULT_BASE_URL,
                        help="Set base URL for images download location")
    parser.add_argument("-k", "--keep", action="store_true", default=False,
                        help="Keep the downloaded images archives in the image directory")
    parser.add_argument("-T", "--test", action="store_true", default=False,
                        help="Verify the downloaded archives before extracting them")
    parser.add_argument("-y", "--yes", action="store_true", default=False,
                        help="Answer all questions with 'yes' (for scripting purposes).")
    parser.add_argument("-n", "--dry-run", action="store_true", default=False,
                        help="Print selected target without actually downloading them.")
    parser.add_argument("--refetch", action="store_true", default=False,
                        help="Ignore the inventory file and download all images.")
    parser.add_argument('-V', '--version', action='version', version=_UHD_VERSION)
    parser.add_argument('-q', '--quiet', action='count', default=0,
                        help="Decrease verbosity level")
    parser.add_argument('-v', '--verbose', action='count', default=0,
                        help="Increase verbosity level")
    # Some sanitation that's easier to handle outside of the argparse framework:
    args = parser.parse_args()
    if not args.base_url.endswith('/') and args.base_url != "":
        args.base_url += '/'
    if args.yes:
        global _YES
        _YES = True
    if args.http_proxy:
        global _PROXIES
        _PROXIES['http'] = args.http_proxy
    # Set the verbosity
    global _LOG_LEVEL
    log("TRACE", "Default log level: {}".format(_LOG_LEVEL))
    _LOG_LEVEL = _LOG_LEVEL - args.verbose + args.quiet
    return args


def get_images_dir(args):
    """
    Figure out where to store the images.
    """
    if args.install_location:
        return args.install_location
    if os.environ.get("UHD_IMAGES_DIR"):
        log("DEBUG",
            "UHD_IMAGES_DIR environment variable is set, using to set "
            "install location.")
        return os.environ.get("UHD_IMAGES_DIR")
    return _DEFAULT_INSTALL_PATH


def ask_permission(question, default_no=True):
    """
    Ask the question, and have the user type y or n on the keyboard. If the
    global variable _YES is true, this always returns True without asking the
    question. Otherwise, return True if the answer is 'yes', 'y', or something
    similar.
    """
    if _YES:
        log("DEBUG", "Assuming the answer is 'yes' for this question: " +
            question)
        return True
    postfix = "[y/N]" if default_no else "[Y/n]"
    answer = input(question + " " + postfix)
    if answer and answer[0].lower() == 'y':
        return True
    return False


class TemporaryDirectory:
    """Class to create a temporary directory"""
    def __enter__(self):
        try:
            self.name = tempfile.mkdtemp()
            return self.name
        except Exception as ex:
            log("ERROR", "Failed to create a temporary directory (%s)" % ex)
            raise ex

    # Can return 'True' to suppress incoming exception
    def __exit__(self, exc_type, exc_value, traceback):
        try:
            shutil.rmtree(self.name)
            log("TRACE", "Temp directory deleted.")
        except Exception as ex:
            log("ERROR", "Could not delete temporary directory: %s (%s)" % (self.name, ex))
        return exc_type is None


def get_manifest_raw(args):
    """
    Return the raw content of the manifest (i.e. the text file). It
    needs to be parsed to be of any practical use.
    """
    # If we're given a path to a manifest file, use it
    if os.path.exists(args.manifest_location):
        manifest_fn = args.manifest_location
        log("INFO", "Using manifest file at location: {}".format(manifest_fn))
        with open(manifest_fn, 'r') as manifest_file:
            manifest_raw = manifest_file.read()
    # Otherwise, use the CMake Magic manifest
    else:
        manifest_raw = _MANIFEST_CONTENTS
    log("TRACE", "Raw manifest contents: {}".format(manifest_raw))
    return manifest_raw


def parse_manifest(manifest_contents):
    """Parse the manifest file, returns a dictionary of potential targets"""
    manifest = {}
    for line in manifest_contents.split('\n'):
        line_unpacked = line.split()
        try:
            # Check that the line isn't empty or a comment
            if not line_unpacked or line.strip().startswith('#'):
                continue

            target, repo_hash, url, sha256_hash = line_unpacked
            manifest[target] = {"repo_hash": repo_hash,
                                "url": url,
                                "sha256_hash": sha256_hash,
                               }
        except ValueError:
            log("WARN", "Warning: Invalid line in manifest file:\n"
                "         {}".format(line))
            continue
    return manifest


def parse_inventory(inventory_fn):
    """Parse the inventory file, returns a dictionary of installed files"""
    try:
        if not os.path.exists(inventory_fn):
            log("INFO", "No inventory file found at {}. Creating an empty one.".format(inventory_fn))
            return {}
        with open(inventory_fn, 'r') as inventory_file:
            # TODO: verify the contents??
            return json.load(inventory_file)
    except Exception as ex:
        log("WARN", "Error parsing the inventory file. Assuming an empty inventory: {}".format(ex))
        return {}


def write_inventory(inventory, inventory_fn):
    """Writes the inventory to file"""
    try:
        with open(inventory_fn, 'w') as inventory_file:
            json.dump(inventory, inventory_file)
            return True
    except Exception as ex:
        log("ERROR", "Error writing the inventory file. Contents may be incomplete or corrupted.\n"
                     "Error message: {}".format(ex))
        return False


def lookup_urls(regex_l, manifest, inventory, refetch=False):
    """Takes a list of RegExs to match within the manifest, returns a list of tuples with
    (hash, URL) that match the targets and are not in the inventory"""
    selected_targets = []
    # Store whether or not we've found a target in the manifest that matches the requested type
    found_one = False
    for target in manifest.keys():
        # Iterate through the possible targets in the manifest.
        # If any of them match any of the RegExs supplied, add the URL to the
        # return list
        if all(map((lambda regex: re.findall(regex, target)), regex_l)):
            found_one = True
            log("TRACE", "Selected target: {}".format(target))
            target_info = manifest.get(target)
            target_url = target_info.get("url")
            target_hash = target_info.get("repo_hash")
            target_sha256 = target_info.get("sha256_hash")
            filename = os.path.basename(target_url)
            # Check if the same filename and hash appear in the inventory
            if not refetch and inventory.get(target, {}).get("repo_hash", "") == target_hash:
                # We already have this file, we don't need to download it again
                log("INFO", "Target {} is up to date.".format(target))
            else:
                # We don't have that exact file, add it to the list
                selected_targets.append({"target": target,
                                         "repo_hash": target_hash,
                                         "filename": filename,
                                         "url": target_url,
                                         "sha256_hash": target_sha256})
    if not found_one:
        log("INFO", "No targets matching '{}'".format(regex_l))
    return selected_targets


def print_target_list(manifest, args):
    """
    Print a list of targets.
    """
    char_offset = max(len(x) for x in manifest.keys())
    if not args.url_only:
        # Print a couple helpful lines,
        # then print each (Target, URL) pair in the manifest
        log("INFO", "Potential targets in manifest file:\n"
                    "{} : {}".format(
                        "# TARGET".ljust(char_offset),
                        "URL" if args.base_url else "RELATIVE_URL"))
        for key, value in sorted(manifest.items()):
            print("{target} : {base}{relpath}".format(
                target=key.ljust(char_offset),
                base=args.base_url,
                relpath=value["url"]))
    else:
        for manifest_item in iteritems(manifest):
            print(args.base_url+manifest_item[1]["url"])


def download(
        images_url,
        filename,
        buffer_size=_DEFAULT_BUFFER_SIZE,
        print_progress=False,
        download_limit=None
    ):
    """ Run the download, show progress """
    download_limit = download_limit or _DEFAULT_DOWNLOAD_LIMIT
    log("TRACE", "Downloading {} to {}".format(images_url, filename))
    try:
        resp = requests.get(images_url, stream=True, proxies=_PROXIES,
                            headers={'User-Agent': 'UHD Images Downloader'})
    except TypeError:
        # requests library versions pre-4c3b9df6091b65d8c72763222bd5fdefb7231149
        # (Dec.'12) workaround
        resp = requests.get(images_url, prefetch=False, proxies=_PROXIES,
                            headers={'User-Agent': 'UHD Images Downloader'})
    if resp.status_code != 200:
        raise RuntimeError("URL does not exist: {}".format(images_url))
    filesize = float(resp.headers['content-length'])
    if filesize > download_limit:
        if not ask_permission(
                "The file size for this target ({:.1f} MiB) exceeds the "
                "download limit ({:.1f} MiB). Continue downloading?".format(
                    filesize/1024**2, download_limit/1024**2)):
            return 0, 0, ""
    filesize_dl = 0
    base_filename = os.path.basename(filename)
    if print_progress and not sys.stdout.isatty():
        print_progress = False
        log("INFO", "Downloading {}, total size: {} kB".format(
            base_filename, filesize/1000))
    with open(filename, "wb") as temp_file:
        sha256_sum = hashlib.sha256()
        for buff in resp.iter_content(chunk_size=buffer_size):
            if buff:
                temp_file.write(buff)
                filesize_dl += len(buff)
                sha256_sum.update(buff)
            if print_progress:
                status = r"%05d kB / %05d kB (%03d%%) %s" % (
                    int(math.ceil(filesize_dl / 1000.)), int(math.ceil(filesize / 1000.)),
                    int(math.ceil(filesize_dl * 100.) / filesize),
                    base_filename)
                if os.name == "nt":
                    status += chr(8) * (len(status) + 1)
                else:
                    sys.stdout.write("\x1b[2K\r")  # Clear previous line
                sys.stdout.write(status)
                sys.stdout.flush()
    if print_progress:
        print('')
    return filesize, filesize_dl, sha256_sum.hexdigest()


def delete_from_inv(target_info, inventory, images_dir):
    """
    Uses the inventory to delete the contents of the archive file specified by in `target_info`
    """
    target = inventory.get(target_info.get("target"), {})
    target_name = target.get("target")
    log("TRACE", "Removing contents of {} from inventory ({})".format(
        target, target.get("contents", [])))
    dirs_to_delete = []
    # Delete all of the files
    for image_fn in target.get("contents", []):
        image_path = os.path.join(images_dir, image_fn)
        if os.path.isfile(image_path):
            os.remove(image_path)
            log("TRACE", "Deleted {} from inventory".format(image_path))
        elif os.path.isdir(image_path):
            dirs_to_delete.append(image_fn)
        else: # File doesn't exist
            log("WARN", "File {} in inventory does not exist".format(image_path))
    # Then delete all of the (empty) directories
    for dir_path in dirs_to_delete:
        try:
            if os.path.isdir(dir_path):
                os.removedirs(dir_path)
        except os.error as ex:
            log("ERROR", "Failed to delete dir: {}".format(ex))
    inventory.pop(target_name, None)
    return True


def extract(archive_path, images_dir, test_zip=False):
    """
    Extract the contents of `archive_path` into `images_dir`

    Returns a list of files that were extracted.

    If test_zip is set, it will verify the zip file before extracting.
    """
    log("TRACE", "Attempting to extracted files from {}".format(archive_path))
    with zipfile.ZipFile(archive_path) as images_zip:
        # Check that the Zip file is valid, in which case `testzip()` returns
        # None.  If it's bad, that function will return a list of bad files
        try:
            if test_zip and images_zip.testzip():
                log("ERROR", "Could not extract the following invalid Zip file:"
                             " {}".format(archive_path))
                return []
        except OSError:
            log("ERROR", "Could not extract the following invalid Zip file:"
                         " {}".format(archive_path))
            return []
        images_zip.extractall(images_dir)
        archive_namelist = images_zip.namelist()
        log("TRACE", "Extracted files: {}".format(archive_namelist))
        return archive_namelist


def update_target(target_info, temp_dir, images_dir, inventory, args):
    """
    Handle the updating of a single target.
    """
    target_name = target_info.get("target")
    target_sha256 = target_info.get("sha256_hash")
    filename = target_info.get("filename")
    temp_path = os.path.join(temp_dir, filename)
    # Add a trailing slash to make sure that urljoin handles things properly
    full_url = urljoin(args.base_url+'/', target_info.get("url"))
    _, downloaded_size, downloaded_sha256 = download(
        images_url=full_url,
        filename=temp_path,
        buffer_size=args.buffer_size,
        print_progress=(_LOG_LEVEL <= _LOG_LEVELS.get("INFO", 3))
    )
    if downloaded_size == 0:
        log("INFO", "Skipping target: {}".format(target_name))
        return
    log("TRACE", "{} successfully downloaded ({} Bytes)"
        .format(temp_path, downloaded_size))
    # If the SHA256 in the manifest has the value '0', this is a special case
    # and we just skip the verification step
    if target_sha256 == '0':
        log("DEBUG", "Skipping SHA256 check for {}.".format(full_url))
    # If the check fails, print an error and don't unzip the file
    elif downloaded_sha256 != target_sha256:
        log("ERROR", "Downloaded SHA256 does not match manifest for {}!"
            .format(full_url))
        return
        # Note: this skips the --keep option, so we'll never keep image packages
        #       that fail the SHA256 checksum
    ## Now copy the contents to the final destination (the images directory)
    delete_from_inv(target_info, inventory, images_dir)
    if os.path.splitext(temp_path)[1].lower() == '.zip':
        archive_namelist = extract(
            temp_path,
            images_dir,
            args.test)
        if args.keep:
            # If the user wants to keep the downloaded archive,
            # save it to the images directory and add it to the inventory
            shutil.copy(temp_path, images_dir)
            archive_namelist.append(filename)
    else:
        archive_namelist = []
        shutil.copy(temp_path, images_dir)
    ## Update inventory
    inventory[target_name] = {"repo_hash": target_info.get("repo_hash"),
                              "contents": archive_namelist,
                              "filename": filename}


def main():
    """
    Main function; does whatever the user requested (download or list files).

    Returns True on successful execution.
    """
    args = parse_args()
    images_dir = get_images_dir(args)
    log("INFO", "Images destination: {}".format(os.path.abspath(images_dir)))
    try:
        manifest = parse_manifest(get_manifest_raw(args))
        if args.list_targets:
            print_target_list(
                manifest,
                args
            )
            return True
        log("TRACE", "Manifest:\n{}".format(
            "\n".join("{}".format(item) for item in manifest.items())
        ))

        # Read the inventory into a dictionary we can perform lookups on
        if os.path.isfile(args.inventory_location):
            inventory_fn = args.inventory_location
        else:
            inventory_fn = os.path.join(images_dir, _INVENTORY_FILENAME)
        inventory = parse_inventory(inventory_fn=inventory_fn)
        log("TRACE", "Inventory: {}\n{}".format(
            os.path.abspath(inventory_fn),
            "\n".join("{}".format(item) for item in inventory.items())
        ))

        # Determine the URLs to download based on the input regular expressions
        if not args.types:
            types_regex_l = [_DEFAULT_TARGET_REGEX]
        else:
            types_regex_l = args.types

        log("TRACE", "RegExs for target selection: {}".format(types_regex_l))
        targets_info = lookup_urls(types_regex_l, manifest, inventory, args.refetch)
        # Exit early if we don't have anything to download
        if targets_info:
            target_urls = [info.get("url") for info in targets_info]
            log("DEBUG", "URLs to download:\n{}".format(
                "\n".join("{}".format(item) for item in target_urls)
            ))
        else:
            return True

        ## Now download all the images archives into a temp directory
        if args.dry_run:
            for target_info in targets_info:
                log("INFO", "[Dry Run] Fetch target: {}".format(
                    target_info.get("filename")))
            return True
        with TemporaryDirectory() as temp_dir:
            for target_info in targets_info:
                update_target(
                    target_info,
                    temp_dir,
                    images_dir,
                    inventory,
                    args
                )
        ## Update inventory with all the new content
        write_inventory(inventory, inventory_fn)

    except Exception as ex:
        log("ERROR", "Downloader raised an unhandled exception: {ex}\n"
            "You can run this again with the '--verbose' flag to see more information\n"
            "If the problem persists, please email the output to: {contact}"
            .format(contact=_CONTACT, ex=ex))
        # Again, we wait on Windows systems because if this is executed in a
        # window, and immediately fails, the user doesn't have a way to see the
        # error message, and if they're not very savvy, they won't know how to
        # execute this in a shell.
        if not _YES and platform.system() == 'Windows':
            input('Hit Enter to continue.')
        return False
    log("INFO", "Images download complete.")
    return True

# Placing this near the end of the file so we don't clutter the top
_MANIFEST_CONTENTS = """@CMAKE_MANIFEST_CONTENTS@"""
if __name__ == "__main__":
    sys.exit(not main())