#!/usr/bin/env python
#
# Copyright 2012-2014 Ettus Research LLC
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

import sys
import os
import tempfile
import math
import traceback
import shutil
import hashlib
import urllib2
import zipfile

from optparse import OptionParser

# NOTE(review): the "@NAME@" values below look like placeholders that the build
# system (presumably CMake's configure_file) substitutes when this script is
# generated -- confirm against the build files.

# Chunk size in bytes: used when hashing a file and as the default download buffer size.
_DEFAULT_BUFFER_SIZE      = 8192
# Path components appended to the install prefix, and also expected inside the extracted archive.
_BASE_DIR_STRUCTURE_PARTS = ["share", "uhd", "images"]
# Default install location (overridable by UHD_IMAGES_DIR or --install-location).
_DEFAULT_INSTALL_PATH     = os.path.join("@CMAKE_INSTALL_PREFIX@", *_BASE_DIR_STRUCTURE_PARTS)
# Default download location (overridable by UHD_IMAGES_BASE_URL or --base-url).
_DEFAULT_BASE_URL         = "http://files.ettus.com/binaries/images/"
# Default archive filename (--filename).
_AUTOGEN_IMAGES_FILENAME  = "@UHD_IMAGES_DOWNLOAD_SRC@"
# Default expected checksum of the archive (--checksum).
_AUTOGEN_IMAGES_CHECKSUM  = "@UHD_IMAGES_MD5SUM@"
# Default checksum algorithm name (--checksum-type); looked up in _checksum_fns.
_IMAGES_CHECKSUM_TYPE     = "md5"
# Address printed in error messages so users know where to send the output.
_CONTACT                  = "support@ettus.com"

def md5Checksum(filePath):
    """Return the hexadecimal MD5 digest of the file at filePath.

    The file is opened in binary mode and read in chunks of
    _DEFAULT_BUFFER_SIZE bytes, so it is never held in memory as a whole.

    Any exception raised while opening, reading or hashing the file is
    reported on stdout and then re-raised unchanged to the caller.
    """
    try:
        digest = hashlib.md5()
        with open(filePath, 'rb') as fh:
            # iter() with a sentinel stops at the first empty read, i.e. at EOF.
            for chunk in iter(lambda: fh.read(_DEFAULT_BUFFER_SIZE), b''):
                digest.update(chunk)
        return digest.hexdigest()
    except Exception as e:
        print("Failed to calculate MD5 sum of: %s (%s)" % (filePath, e))
        # A bare 'raise' preserves the original traceback; 'raise e' would
        # restart the traceback at this line on Python 2.
        raise

# Maps a lower-case checksum type name (as accepted by --checksum-type) to the
# function that computes that checksum for a given file path.
_checksum_fns = {
    'md5': md5Checksum
}

class temporary_directory(object):
    """Context manager providing a scratch directory that is removed on exit.

    Usage:
        with temporary_directory() as path:
            ...  # 'path' is the name of a freshly created directory

    On entry a directory is created with tempfile.mkdtemp() and its path is
    returned (and stored in self.name). On exit the whole tree is deleted.
    A failure to delete is reported on stdout but not raised (best effort).
    Exceptions raised inside the 'with' body are never suppressed.
    """
    def __enter__(self):
        try:
            self.name = tempfile.mkdtemp()
        except Exception as e:
            print("Failed to create a temporary directory (%s)" % (e))
            # Bare 'raise' keeps the original traceback ('raise e' would not on Python 2).
            raise
        return self.name

    # Can return 'True' to suppress incoming exception; this implementation
    # returns None, so exceptions from the 'with' body always propagate.
    def __exit__(self, type, value, traceback):
        try:
            shutil.rmtree(self.name)
        except Exception as e:
            # Cleanup is best effort: report the problem but do not mask the
            # outcome of the 'with' body.
            print("Could not delete temporary directory: %s (%s)" % (self.name, e))

class uhd_images_downloader(object):
    """Steps for fetching, verifying, unpacking and installing an images archive.

    The object keeps no state; each method is an independent step and reports
    optional progress on stdout when print_progress is true.
    """
    def __init__(self):
        pass

    def download(self, images_url, filename, buffer_size=_DEFAULT_BUFFER_SIZE, print_progress=False):
        """ Run the download, show progress

        Fetches images_url and writes the body to the local file 'filename',
        reading buffer_size bytes at a time.

        Returns a tuple (reported_size, downloaded_size): reported_size is the
        Content-Length announced by the server as a float (0.0 when the header
        is missing or unparsable), downloaded_size is the number of bytes
        actually written.

        Errors from opening the URL or writing the file propagate to the caller.
        """
        opener = urllib2.build_opener()
        # The attribute consulted by OpenerDirector is 'addheaders'; any other
        # attribute name is silently ignored and the default User-Agent is sent.
        opener.addheaders = [('User-Agent', 'UHD Images Downloader')]
        u = opener.open(images_url)
        try:
            # Servers may omit Content-Length (e.g. chunked transfer); treat
            # the size as unknown instead of failing the whole download.
            try:
                filesize = float(u.info().getheaders("Content-Length")[0])
            except (IndexError, ValueError):
                filesize = 0.
            filesize_dl = 0
            with open(filename, "wb") as f:
                while True:
                    buff = u.read(buffer_size)
                    if not buff:
                        break
                    f.write(buff)
                    filesize_dl += len(buff)
                    if print_progress:
                        if filesize > 0:
                            status = r"%05d kB / %05d kB (%03d%%)" % (
                                int(math.ceil(filesize_dl/1000.)),
                                int(math.ceil(filesize/1000.)),
                                int(filesize_dl*100./filesize))
                        else:
                            # Total unknown: only the running count can be shown.
                            status = r"%05d kB" % (int(math.ceil(filesize_dl/1000.)))
                        # Backspaces return the cursor to the start of the line
                        # so the next status overwrites this one.
                        sys.stdout.write(status + chr(8)*len(status))
                        # Without a newline nothing is shown until flushed.
                        sys.stdout.flush()
        finally:
            u.close()
        if print_progress:
            print("")
        return (filesize, filesize_dl)

    def check_directories(self, dirs, print_progress=False):
        """Check that the directory 'dirs' can be created and/or written to.

        'dirs' (None or "" meaning the current directory) is made absolute;
        then the deepest path component that already exists is located by
        walking up towards the root.

        Returns a tuple (ok, path): ok is True when that existing component is
        a directory for which os.access() reports write permission; path is
        the component that was examined. If even the root of the path does not
        exist, (False, root) is returned.
        """
        path = os.path.abspath(dirs or ".")
        at_root = False
        while True:
            if print_progress:
                print("Checking: %s" % (path))
            if os.path.exists(path):
                break
            if print_progress:
                print("Does not exist: %s" % (path))
            if at_root:
                # Nothing further to walk up to. Always return a 2-tuple so
                # callers can unpack the result.
                return (False, path)
            (path, tail) = os.path.split(path)
            # os.path.split() yields an empty tail once the root is reached.
            at_root = (tail == "")

        if not os.path.isdir(path):
            if print_progress:
                print("Is not a directory: %s" % (path))
            return (False, path)
        if not os.access(path, os.W_OK):
            if print_progress:
                print("Write permission denied on: %s" % (path))
            return (False, path)
        if print_progress:
            print("Write permission granted on: %s" % (path))
        return (True, path)

    def validate_checksum(self, checksum_fn, file_path, expecting, print_progress=False):
        """Compare the checksum of file_path with the expected value.

        checksum_fn is a callable taking a file path and returning the
        checksum; when it is None no check is made and (True, "") is returned.

        Returns a tuple (match, calculated). match is False only when
        'expecting' is non-empty and differs from the calculated value.
        print_progress is accepted for signature symmetry and is not used.
        """
        if checksum_fn is None:
            return (True, "")
        calculated_checksum = checksum_fn(file_path)
        if expecting and calculated_checksum != expecting:
            return (False, calculated_checksum)
        return (True, calculated_checksum)

    def extract_images_archive(self, archive_path, destination=None, print_progress=False):
        """Unzip archive_path into a sibling directory and return that directory.

        The extraction directory sits next to the archive and is named after
        it with the extension removed; an existing directory of that name is
        deleted first. 'destination' is currently not used.

        Raises Exception when the archive does not exist or its directory is
        not writable; errors from the zipfile module and the file system
        propagate.
        """
        if not os.path.exists(archive_path):
            if print_progress:
                print("Path does not exist: %s" % (archive_path))
            raise Exception("path does not exist: %s" % (archive_path))
        if print_progress:
            print("Archive path: %s" % (archive_path))
        (head, tail) = os.path.split(archive_path)

        # A bare filename has an empty directory part, which os.access()
        # would reject; it refers to the current directory.
        parent_dir = head or os.curdir
        if not os.access(parent_dir, os.W_OK):
            if print_progress:
                print("Write access denied on: %s" % (parent_dir))
            raise Exception("write access denied on: %s" % (parent_dir))

        (root, ext) = os.path.splitext(tail)
        temp_dir = os.path.join(head, root)

        if print_progress:
            print("Temporary extraction location: %s" % (temp_dir))

        if os.path.exists(temp_dir):
            if print_progress:
                print("Deleting existing location: %s" % (temp_dir))
            shutil.rmtree(temp_dir)

        if print_progress:
            print("Creating directory: %s" % (temp_dir))
        os.mkdir(temp_dir)

        if print_progress:
            print("Extracting archive %s to %s" % (archive_path, temp_dir))

        images_zip = zipfile.ZipFile(archive_path)
        try:
            images_zip.extractall(temp_dir)
        finally:
            # Close the archive even when extraction fails part-way.
            images_zip.close()

        return temp_dir

    def install_images(self, source, dest, keep=False, print_progress=False):
        """Copy the images from an extracted archive at 'source' into 'dest'.

        The files are taken from <source>/<basename of source>/ followed by
        _BASE_DIR_STRUCTURE_PARTS. With keep=False an existing 'dest' tree is
        deleted and replaced; with keep=True files are copied over the
        existing tree, overwriting files of the same name.

        Returns None. Does nothing (beyond an optional message) when 'source'
        does not exist. Raises Exception when the expected sub-tree is missing
        from 'source'; in that case 'dest' is left untouched.
        """
        if not os.path.exists(source):
            if print_progress:
                print("Source path does not exist: %s" % (source))
            return

        (head, tail) = os.path.split(source)
        uhd_source = os.path.join(source, tail, *_BASE_DIR_STRUCTURE_PARTS)

        # Validate the archive layout before anything is deleted, so an
        # unexpected archive cannot wipe a working installation.
        if not os.path.exists(uhd_source):
            if print_progress:
                print("Path does not exist: %s" % (uhd_source))
            raise Exception("path does not exist: %s" % (uhd_source))

        if keep:
            if print_progress:
                print("Not wiping directory tree (existing files will be overwritten): %s" % (dest))
        elif os.path.exists(dest):
            if print_progress:
                print("Deleting directory tree: %s" % (dest))
            shutil.rmtree(dest)

        if print_progress:
            print("Source install path: %s" % (source))
            print("Copying files from: %s" % (uhd_source))
            print("Copying files to:   %s" % (dest))

        if not keep:
            shutil.copytree(uhd_source, dest)
            return

        # shutil.copytree() refuses to copy into an existing directory, so the
        # 'keep' case merges the trees by hand.
        # mgrant @ http://stackoverflow.com/questions/12683834/how-to-copy-directory-recursively-in-python-and-overwrite-all
        def _recursive_overwrite(src, dst):
            if not os.path.isdir(src):
                shutil.copyfile(src, dst)
                return
            if not os.path.isdir(dst):
                os.makedirs(dst)
            for entry in os.listdir(src):
                _recursive_overwrite(os.path.join(src, entry), os.path.join(dst, entry))

        _recursive_overwrite(uhd_source, dest)

def _default_from_env(var_name, description, fallback):
    """Return the value of environment variable var_name, or fallback when it is unset or empty.

    When the environment value is used, a notice naming the variable and the
    default it provides ('description') is printed.
    """
    value = os.environ.get(var_name)
    if not value:
        return fallback
    print("{0} environment variable is set.\nDefault {1}: {2}".format(var_name, description, value))
    return value

def _print_failure_help(verbose):
    """Print the current traceback (verbose) or a hint to use --verbose, then the support contact.

    Must be called from inside an 'except' block so that a traceback exists.
    """
    if verbose:
        traceback.print_exc()
    else:
        print("You can run this again with the '--verbose' flag to see more information")
    print("If the problem persists, please email the output to: %s" % (_CONTACT))

def main():
    """Command-line entry point: fetch, verify and install the images archive.

    Defaults come from the UHD_IMAGES_DIR / UHD_IMAGES_BASE_URL environment
    variables when set, otherwise from the module constants; command-line
    options override both. A base URL starting with 'http' is downloaded,
    anything else is treated as a local directory containing the archive.

    Returns the process exit status: 0 on success (and when cancelled with
    Ctrl-C), 1 on invalid options, insufficient permissions, a missing local
    archive, a checksum mismatch, a failed installation or any other error.
    """
    ### Set defaults from env variables
    default_images_dir = _default_from_env("UHD_IMAGES_DIR", "install location", _DEFAULT_INSTALL_PATH)
    default_base_url = _default_from_env("UHD_IMAGES_BASE_URL", "base URL", _DEFAULT_BASE_URL)

    ### Setup argument parser and parse
    parser = OptionParser()
    parser.add_option("-i", "--install-location",   type="string",          default=default_images_dir,
                        help="Set custom install location for images [default=%default]")
    parser.add_option("--buffer-size",              type="int",             default=_DEFAULT_BUFFER_SIZE,
                        help="Set download buffer size [default=%default]")
    parser.add_option("-b", "--base-url",           type="string",          default=default_base_url,
                        help="Set base URL for images download location [default=%default]")
    parser.add_option("-f", "--filename",           type="string",          default=_AUTOGEN_IMAGES_FILENAME,
                        help="Set images archive filename [default=%default]")
    parser.add_option("-c", "--checksum",           type="string",          default=_AUTOGEN_IMAGES_CHECKSUM,
                        help="Validate images archive against this checksum (blank to skip) [default=%default]")
    parser.add_option("-t", "--checksum-type",      type="string",          default=_IMAGES_CHECKSUM_TYPE,
                        help=("Select checksum hash function (options: %s) [default=%%default]" % (",".join(_checksum_fns))))
    parser.add_option("-k", "--keep",               action="store_true",    default=False,
                        help="Do not clear images directory before extracting new files [default=%default]")
    parser.add_option("-v", "--verbose",            action="store_true",    default=False,
                        help="Enable verbose output [default=%default]")
    (options, args) = parser.parse_args()
    if options.buffer_size <= 0:
        print("Invalid buffer size: %s" % (options.buffer_size))
        return 1

    ### Select checksum algorithm (MD5)
    # A blank checksum disables validation: checksum_fn stays None.
    checksum_fn = None
    if options.checksum != "":
        options.checksum_type = options.checksum_type.lower()
        if options.checksum_type not in _checksum_fns:
            print("Not a supported checksum function: %s" % (options.checksum_type))
            return 1
        checksum_fn = _checksum_fns[options.checksum_type]

    ### Check if base URL is a local dir or off the webs
    images_dir = os.path.abspath(options.install_location)  # This will use the current working directory if it's not absolute
    images_url = None
    base_url_is_local = not options.base_url.startswith('http')
    if not base_url_is_local:
        if not options.base_url.endswith('/'):
            options.base_url += '/'
        images_url = options.base_url + options.filename

    if options.verbose:
        print("Requested install location: %s" % (options.install_location))
        print("Images base URL:            %s" % (options.base_url))
        print("Images filename:            %s" % (options.filename))
        # Show the checksum type actually selected, not the built-in default.
        print("Images checksum:            %s (%s)" % (options.checksum, options.checksum_type))
        print("Final install location:     %s" % (images_dir))
        print("Copying locally:            {0}".format("Yes" if base_url_is_local else "No"))
    else:
        print("Images destination:      %s" % (images_dir))

    ### Download or copy
    downloader = uhd_images_downloader()
    try:
        (access, last_path) = downloader.check_directories(images_dir, print_progress=options.verbose)
        if not access:
            print("You do not have sufficient permissions to write to: %s" % (last_path))
            print("Are you root?")
            return 1
        # All intermediate files live in a scratch directory that is removed
        # when the 'with' block is left, including via 'return'.
        with temporary_directory() as temp_dir:
            if options.verbose:
                print("Using temporary directory: %s" % (temp_dir))
            temp_images_dest = os.path.join(temp_dir, options.filename)
            if not base_url_is_local:
                print("Downloading images from: {0}".format(images_url))
                print("Downloading images to:   {0}".format(temp_images_dest))
                (reported_size, downloaded_size) = downloader.download(
                        images_url=images_url,
                        filename=temp_images_dest,
                        buffer_size=options.buffer_size,
                        print_progress=True
                )
                if options.verbose:
                    print("Downloaded %d of %d bytes" % (downloaded_size, reported_size))
            else:
                local_images_pkg = os.path.join(options.base_url, options.filename)
                print("Copying images from:     {0}".format(local_images_pkg))
                if not os.path.isfile(local_images_pkg):
                    print("[ERROR] No such file.")
                    return 1
                shutil.copyfile(local_images_pkg, temp_images_dest)
            (checksum_match, calculated_checksum) = downloader.validate_checksum(
                    checksum_fn,
                    temp_images_dest,
                    options.checksum,
                    print_progress=options.verbose
            )
            if options.verbose:
                print("Calculated checksum: %s" % (calculated_checksum))
            if not checksum_match:
                print("Checksum of downloaded file is not correct (not installing - see options to override)")
                print("Expected:   %s" % (options.checksum))
                print("Calculated: %s" % (calculated_checksum))
                print("Please try downloading again.")
                print("If the problem persists, please email the output to: %s" % (_CONTACT))
                # Nothing was installed: report failure to the calling shell.
                return 1
            if options.verbose:
                if options.checksum == "":
                    print("Ignoring checksum")
                else:
                    print("Checksum OK")
            try:
                extract_path = downloader.extract_images_archive(temp_images_dest, print_progress=options.verbose)
                if options.verbose:
                    print("Image archive extracted to: %s" % (extract_path))
                downloader.install_images(extract_path, images_dir, options.keep, print_progress=options.verbose)
                if options.verbose:
                    print("Cleaning up temp location: %s" % (extract_path))
                shutil.rmtree(extract_path)
                print("")
                print("Images successfully installed to: %s" % (images_dir))
            except Exception as e:
                print("Failed to install image archive: %s" % (e))
                print("This is usually a permissions problem.")
                print("Please check your file system access rights and try again.")
                _print_failure_help(options.verbose)
                # The installation did not complete: report failure.
                return 1
    except KeyboardInterrupt:
        print("")
        print("Cancelled at user request")
    except Exception as e:
        print("Downloader raised an unhandled exception: %s" % (e))
        _print_failure_help(options.verbose)
        return 1
    return 0

# Run main() only when executed as a script (not when imported); the value it
# returns is passed to sys.exit() and so becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())