#!/usr/bin/env python
#
# Copyright 2012-2015 Ettus Research LLC
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

import sys
import os
import tempfile
import math
import traceback
import shutil
import hashlib
import requests
import zipfile

from optparse import OptionParser

# Module-wide defaults.
# NOTE(review): the "@...@" tokens below look like CMake configure-time
# placeholders that are substituted when this file is generated -- confirm
# against the build system before editing them.

# Chunk size in bytes: used when hashing a file and as the default for the
# download buffer / the --buffer-size option.
_DEFAULT_BUFFER_SIZE      = 8192
# Path components appended to the install prefix; also used to locate the
# images inside the extracted archive.
_BASE_DIR_STRUCTURE_PARTS = ["share", "uhd", "images"]
# Default value of --install-location (unless UHD_IMAGES_DIR is set).
_DEFAULT_INSTALL_PATH     = os.path.join("@CMAKE_INSTALL_PREFIX@", *_BASE_DIR_STRUCTURE_PARTS)
# Default value of --base-url (unless UHD_IMAGES_BASE_URL is set).
_DEFAULT_BASE_URL         = "http://files.ettus.com/binaries/images/"
# Default values of --filename and --checksum.
_AUTOGEN_IMAGES_FILENAME  = "@UHD_IMAGES_DOWNLOAD_SRC@"
_AUTOGEN_IMAGES_CHECKSUM  = "@UHD_IMAGES_MD5SUM@"
# Default value of --checksum-type; must be a key of _checksum_fns.
_IMAGES_CHECKSUM_TYPE     = "md5"
# Address printed in error messages so users know where to send the output.
_CONTACT                  = "support@ettus.com"

def md5Checksum(filePath, buffer_size=None):
    """Return the hexadecimal MD5 digest of the file at filePath.

    The file is opened in binary mode and hashed buffer_size bytes at a time,
    so memory use does not grow with the size of the file.

    Args:
        filePath: Path of the file to hash.
        buffer_size: Number of bytes to read per chunk. None (the default)
            selects _DEFAULT_BUFFER_SIZE.

    Returns:
        The digest as a hex string.

    Raises:
        ValueError: If buffer_size is not a positive number.
        Exception: Any error raised while opening or reading the file is
            reported on stdout and then re-raised unchanged.
    """
    if buffer_size is None:
        buffer_size = _DEFAULT_BUFFER_SIZE
    if buffer_size <= 0:
        # read(0) would return an empty chunk immediately and silently yield
        # the digest of an empty file.
        raise ValueError("buffer_size must be positive, got: %s" % (buffer_size))
    try:
        m = hashlib.md5()
        with open(filePath, 'rb') as fh:
            # iter() with a sentinel keeps calling read() until it returns b"" (EOF).
            for data in iter(lambda: fh.read(buffer_size), b""):
                m.update(data)
        return m.hexdigest()
    except Exception as e:
        print("Failed to calculate MD5 sum of: %s (%s)" % (filePath, e))
        # Bare raise re-raises with the original traceback intact.
        raise

# Maps a checksum type name (as accepted by --checksum-type, lower case) to the
# function that computes that checksum for a file path.
_checksum_fns = {
    'md5': md5Checksum
}

class temporary_directory():
    """Context manager around tempfile.mkdtemp().

    Entering creates a new temporary directory and yields its path; leaving
    removes the directory tree again. Failures are reported on stdout.
    """

    def __enter__(self):
        """Create the directory, remember it in self.name and return its path."""
        try:
            self.name = tempfile.mkdtemp()
        except Exception as err:
            print("Failed to create a temporary directory (%s)" % (err))
            raise err
        else:
            return self.name

    def __exit__(self, type, value, traceback):
        """Remove the directory tree; a failed removal is only reported.

        Nothing is returned, so an exception raised inside the with-body keeps
        propagating (returning True here would suppress it).
        """
        try:
            shutil.rmtree(self.name)
        except Exception as err:
            details = (self.name, err)
            print("Could not delete temporary directory: %s (%s)" % details)

class uhd_images_downloader():
    """Stateless collection of the steps needed to fetch and install an images archive.

    main() uses the methods in this order: check_directories(), download()
    (HTTP sources only), validate_checksum(), extract_images_archive() and
    install_images().
    """

    def __init__(self):
        pass

    def download(self, images_url, filename, buffer_size=_DEFAULT_BUFFER_SIZE, print_progress=False):
        """Stream images_url into the local file filename, optionally showing progress.

        Args:
            images_url: URL of the archive to fetch.
            filename: Path of the local file to (over)write.
            buffer_size: Chunk size in bytes requested from the HTTP stream.
            print_progress: When True, a single-line progress indicator is
                written to stdout after every chunk.

        Returns:
            Tuple (filesize, filesize_dl): the size announced by the server's
            Content-Length header as a float (0.0 when the header is missing
            or unusable) and the number of bytes actually written.

        Raises:
            requests.exceptions.HTTPError: If the server answered with an
                error status.
            Exception: Connection and file errors propagate unchanged.
        """
        headers = {'User-Agent': 'UHD Images Downloader'}
        try:
            r = requests.get(images_url, stream=True, headers=headers)
        except TypeError:
            ## requests library versions pre-4c3b9df6091b65d8c72763222bd5fdefb7231149 (Dec.'12) workaround
            r = requests.get(images_url, prefetch=False, headers=headers)
        # Do not store an HTTP error page (404, 500, ...) as if it were the archive.
        r.raise_for_status()
        # Content-Length is optional (e.g. chunked transfers); treat it as unknown then.
        try:
            filesize = float(r.headers.get('content-length') or 0)
        except ValueError:
            filesize = 0.0
        filesize_dl = 0
        with open(filename, "wb") as f:
            for buff in r.iter_content(chunk_size=buffer_size):
                if buff:
                    f.write(buff)
                    filesize_dl += len(buff)
                if print_progress:
                    downloaded_kb = int(math.ceil(filesize_dl/1000.))
                    if filesize > 0:
                        status = r"%05d kB / %05d kB (%03d%%)" % (
                            downloaded_kb,
                            int(math.ceil(filesize/1000.)),
                            int(math.ceil(filesize_dl*100./filesize)))
                    else:
                        # Total size unknown: no percentage can be computed.
                        status = r"%05d kB" % (downloaded_kb)
                    if os.name == "nt":
                        # Backspaces move the cursor back so the next status overwrites this one.
                        status += chr(8)*(len(status)+1)
                    else:
                        sys.stdout.write("\x1b[2K\r") #Clear previous line
                    sys.stdout.write(status)
                    sys.stdout.flush()
        if print_progress:
            print('')
        return (filesize, filesize_dl)

    def check_directories(self, dirs, print_progress=False):
        """Check that dirs, or its nearest existing ancestor, is a writable directory.

        Args:
            dirs: Target directory; None or "" means the current directory.
            print_progress: When True, every step of the check is printed.

        Returns:
            Tuple (ok, path). path is the path that decided the result: the
            nearest existing ancestor of dirs (or dirs itself). ok is True
            only if that path is a directory the process may write to.
        """
        path = os.path.abspath(dirs if dirs else ".")

        # Walk towards the filesystem root until something exists.
        while True:
            if print_progress:
                print("Checking: %s" % (path))
            if os.path.exists(path):
                break
            if print_progress:
                print("Does not exist: %s" % (path))
            (parent, tail) = os.path.split(path)
            if tail == "":
                # Reached a root that does not exist (e.g. an unmapped drive):
                # there is nothing to write to.
                return (False, path)
            path = parent

        if not os.path.isdir(path):
            if print_progress:
                print("Is not a directory: %s" % (path))
            return (False, path)
        if not os.access(path, os.W_OK):
            if print_progress:
                print("Write permission denied on: %s" % (path))
            return (False, path)
        if print_progress:
            print("Write permission granted on: %s" % (path))
        return (True, path)

    def validate_checksum(self, checksum_fn, file_path, expecting, print_progress=False):
        """Compare the checksum of file_path with the expected value.

        Args:
            checksum_fn: Callable taking a file path and returning its
                checksum as a string, or None to skip validation entirely.
            file_path: File to check.
            expecting: Expected checksum; None or "" accepts any checksum.
            print_progress: Accepted for symmetry with the other methods;
                currently unused.

        Returns:
            Tuple (matches, calculated_checksum). calculated_checksum is ""
            when checksum_fn is None.
        """
        if checksum_fn is None:
            return (True, "")
        calculated_checksum = checksum_fn(file_path)
        if expecting is None or expecting == "":
            return (True, calculated_checksum)
        # Hex digests are case-insensitive; also ignore stray surrounding whitespace.
        matches = calculated_checksum.strip().lower() == expecting.strip().lower()
        return (matches, calculated_checksum)

    def extract_images_archive(self, archive_path, destination=None, print_progress=False):
        """Unpack the zip archive into a sibling directory named after it.

        For ".../name.zip" the content ends up in ".../name", which is wiped
        first if it already exists.

        Args:
            archive_path: Path of the zip archive.
            destination: Currently unused; kept for interface compatibility.
            print_progress: When True, every step is printed.

        Returns:
            Path of the directory the archive was extracted to.

        Raises:
            Exception: If archive_path does not exist or its directory is not
                writable. Errors from zipfile / the filesystem propagate.
        """
        if not os.path.exists(archive_path):
            if print_progress:
                print("Path does not exist: %s" % (archive_path))
            raise Exception("path does not exist: %s" % (archive_path))
        if print_progress:
            print("Archive path: %s" % (archive_path))
        (head, tail) = os.path.split(archive_path)
        if head == "":
            # A bare file name lives in the current directory; os.access("") is always False.
            head = os.curdir

        if not os.access(head, os.W_OK):
            if print_progress:
                print("Write access denied on: %s" % (head))
            raise Exception("write access denied on: %s" % (head))

        (root, ext) = os.path.splitext(tail)
        temp_dir = os.path.join(head, root)

        if print_progress:
            print("Temporary extraction location: %s" % (temp_dir))

        if os.path.exists(temp_dir):
            if print_progress:
                print("Deleting existing location: %s" % (temp_dir))
            shutil.rmtree(temp_dir)

        if print_progress:
            print("Creating directory: %s" % (temp_dir))
        os.mkdir(temp_dir)

        if print_progress:
            print("Extracting archive %s to %s" % (archive_path, temp_dir))

        images_zip = zipfile.ZipFile(archive_path)
        try:
            images_zip.extractall(temp_dir)
        finally:
            # Release the file handle even when extraction fails.
            images_zip.close()

        return temp_dir

    def install_images(self, source, dest, keep=False, print_progress=False):
        """Copy the images from an extracted archive into dest.

        The images are expected at
        <source>/<basename of source>/<_BASE_DIR_STRUCTURE_PARTS...>.

        Args:
            source: Directory the archive was extracted to. If it does not
                exist the method returns without doing anything.
            dest: Installation directory.
            keep: If True, dest is kept and files are copied over its current
                content; otherwise an existing dest is deleted first.
            print_progress: When True, every step is printed.

        Raises:
            Exception: If the expected images directory is missing from the
                extracted archive. This is checked before dest is touched.
                Errors from shutil / the filesystem propagate.
        """
        if not os.path.exists(source):
            if print_progress:
                print("Source path does not exist: %s" % (source))
            return

        if print_progress:
            print("Source install path: %s" % (source))

        uhd_source = os.path.join(source, os.path.basename(source), *_BASE_DIR_STRUCTURE_PARTS)

        # Verify the archive layout first so a malformed archive cannot cost
        # the user the currently installed images.
        if not os.path.isdir(uhd_source):
            if print_progress:
                print("Images directory not found in archive: %s" % (uhd_source))
            raise Exception("images directory not found in archive: %s" % (uhd_source))

        if keep:
            if print_progress:
                print("Not wiping directory tree (existing files will be overwritten): %s" % (dest))
        elif os.path.exists(dest):
            if print_progress:
                print("Deleting directory tree: %s" % (dest))
            shutil.rmtree(dest)

        if print_progress:
            print("Copying files from: %s" % (uhd_source))
            print("Copying files to:   %s" % (dest))

        if keep:
            # mgrant @ http://stackoverflow.com/questions/12683834/how-to-copy-directory-recursively-in-python-and-overwrite-all
            def _recursive_overwrite(src, dst):
                if os.path.isdir(src):
                    if not os.path.isdir(dst):
                        os.makedirs(dst)
                    for entry in os.listdir(src):
                        _recursive_overwrite(os.path.join(src, entry), os.path.join(dst, entry))
                else:
                    shutil.copyfile(src, dst)

            _recursive_overwrite(uhd_source, dest)
        else:
            shutil.copytree(uhd_source, dest)

def main():
    """Command-line entry point: fetch, verify and install the images archive.

    The archive is copied from a local directory or downloaded over HTTP
    (depending on the base URL) into a temporary directory, checked against
    the expected checksum, extracted, and installed into the target directory.

    Returns:
        The process exit status: 0 if the images were installed, 1 for
        invalid arguments, insufficient permissions, a missing local archive,
        a checksum mismatch, a failed installation, a user interrupt or any
        other error.
    """
    ### Set defaults from env variables
    default_images_dir = os.environ.get("UHD_IMAGES_DIR")
    if default_images_dir:
        print("UHD_IMAGES_DIR environment variable is set.\nDefault install location: {0}".format(default_images_dir))
    else:
        default_images_dir = _DEFAULT_INSTALL_PATH
    default_base_url = os.environ.get("UHD_IMAGES_BASE_URL")
    if default_base_url:
        print("UHD_IMAGES_BASE_URL environment variable is set.\nDefault base URL: {0}".format(default_base_url))
    else:
        default_base_url = _DEFAULT_BASE_URL

    ### Setup argument parser and parse
    parser = OptionParser()
    parser.add_option("-i", "--install-location",   type="string",          default=default_images_dir,
                        help="Set custom install location for images [default=%default]")
    parser.add_option("--buffer-size",              type="int",             default=_DEFAULT_BUFFER_SIZE,
                        help="Set download buffer size [default=%default]")
    parser.add_option("-b", "--base-url",           type="string",          default=default_base_url,
                        help="Set base URL for images download location [default=%default]")
    parser.add_option("-f", "--filename",           type="string",          default=_AUTOGEN_IMAGES_FILENAME,
                        help="Set images archive filename [default=%default]")
    parser.add_option("-c", "--checksum",           type="string",          default=_AUTOGEN_IMAGES_CHECKSUM,
                        help="Validate images archive against this checksum (blank to skip) [default=%default]")
    parser.add_option("-t", "--checksum-type",      type="string",          default=_IMAGES_CHECKSUM_TYPE,
                        help=("Select checksum hash function (options: %s) [default=%%default]" % (",".join(list(_checksum_fns.keys())))))
    parser.add_option("-k", "--keep",               action="store_true",    default=False,
                        help="Do not clear images directory before extracting new files [default=%default]")
    parser.add_option("-v", "--verbose",            action="store_true",    default=False,
                        help="Enable verbose output [default=%default]")
    parser.add_option("--force-delete",             action="store_true",    default=False,
                        help="Delete all files in the target images directory without prompting [default=%default]")
    (options, args) = parser.parse_args()
    if options.buffer_size <= 0:
        print("Invalid buffer size: %s" % (options.buffer_size))
        return 1

    ### Select checksum algorithm (MD5)
    # A blank checksum disables validation: checksum_fn stays None.
    checksum_fn = None
    if options.checksum != "":
        options.checksum_type = options.checksum_type.lower()
        if options.checksum_type not in _checksum_fns:
            print("Not a supported checksum function: %s" % (options.checksum_type))
            return 1
        checksum_fn = _checksum_fns[options.checksum_type]

    ### Check if base URL is a local dir or off the webs
    images_dir = os.path.abspath(options.install_location)  # This will use the current working directory if it's not absolute
    images_url = None
    base_url_is_local = not options.base_url.startswith('http')
    if not base_url_is_local:
        if not options.base_url.endswith('/'):
            options.base_url += '/'
        images_url = options.base_url + options.filename

    if options.verbose:
        print("Requested install location: %s" % (options.install_location))
        print("Images base URL:            %s" % (options.base_url))
        print("Images filename:            %s" % (options.filename))
        # Report the checksum type actually selected, not the built-in default.
        print("Images checksum:            %s (%s)" % (options.checksum, options.checksum_type))
        print("Final install location:     %s" % (images_dir))
        print("Copying locally:            {0}".format("Yes" if base_url_is_local else "No"))
    else:
        print("Images destination:      %s" % (images_dir))

    ### Check contradictory arguments
    if options.force_delete and options.keep:
        print("Error: Keep and force delete options contradict.\n")
        parser.print_help()
        return 1

    ### Prevent accidental file deletion
    # A custom location may contain unrelated files, so only overwrite there
    # unless the user explicitly asked for a wipe.
    if options.install_location != default_images_dir and not options.force_delete and not options.keep:
        print("Custom install location specified, defaulting to overwriting only image files.\n"
            "Use \'--force-delete\' to clean the target directory first.")
        options.keep = True

    ### Download or copy
    downloader = uhd_images_downloader()
    try:
        (access, last_path) = downloader.check_directories(images_dir, print_progress=options.verbose)
        if not access:
            print("You do not have sufficient permissions to write to: %s" % (last_path))
            print("Are you root?")
            return 1
        with temporary_directory() as temp_dir:
            if options.verbose:
                print("Using temporary directory: %s" % (temp_dir))
            temp_images_dest = os.path.join(temp_dir, options.filename)
            if not base_url_is_local:
                print("Downloading images from: {0}".format(images_url))
                print("Downloading images to:   {0}".format(temp_images_dest))
                (reported_size, downloaded_size) = downloader.download(
                        images_url=images_url,
                        filename=temp_images_dest,
                        buffer_size=options.buffer_size,
                        print_progress=True
                )
                if options.verbose:
                    print("Downloaded %d of %d bytes" % (downloaded_size, reported_size))
            else:
                local_images_pkg = os.path.join(options.base_url, options.filename)
                print("Copying images from:     {0}".format(local_images_pkg))
                if not os.path.isfile(local_images_pkg):
                    print("[ERROR] No such file.")
                    return 1
                shutil.copyfile(local_images_pkg, temp_images_dest)
            (checksum_match, calculated_checksum) = downloader.validate_checksum(
                    checksum_fn,
                    temp_images_dest,
                    options.checksum,
                    print_progress=options.verbose
            )
            if options.verbose:
                print("Calculated checksum: %s" % (calculated_checksum))
            if not checksum_match:
                print("Checksum of downloaded file is not correct (not installing - see options to override)")
                print("Expected:   %s" % (options.checksum))
                print("Calculated: %s" % (calculated_checksum))
                print("Please try downloading again.")
                print("If the problem persists, please email the output to: %s" % (_CONTACT))
                # Nothing was installed: report failure to the calling process.
                return 1
            if options.verbose:
                if options.checksum == "":
                    print("Ignoring checksum")
                else:
                    print("Checksum OK")
            try:
                extract_path = downloader.extract_images_archive(temp_images_dest, print_progress=options.verbose)
                if options.verbose:
                    print("Image archive extracted to: %s" % (extract_path))
                downloader.install_images(extract_path, images_dir, options.keep, print_progress=options.verbose)
                if options.verbose:
                    print("Cleaning up temp location: %s" % (extract_path))
                shutil.rmtree(extract_path)
                print("\nImages successfully installed to: %s" % (images_dir))
            except Exception as e:
                print("Failed to install image archive: %s" % (e))
                print("This is usually a permissions problem.")
                print("Please check your file system access rights and try again.")
                if options.verbose:
                    traceback.print_exc()
                else:
                    print("You can run this again with the '--verbose' flag to see more information")
                print("If the problem persists, please email the output to: %s" % (_CONTACT))
                return 1
    except KeyboardInterrupt:
        print("\nCancelled at user request")
        # An aborted run did not install anything, so it is not a success.
        return 1
    except Exception as e:
        print("Downloader raised an unhandled exception: %s" % (e))
        if options.verbose:
            traceback.print_exc()
        else:
            print("You can run this again with the '--verbose' flag to see more information")
        print("If the problem persists, please email the output to: %s" % (_CONTACT))
        return 1
    return 0

# Run main() only when executed as a script and use its return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())