#!/usr/bin/env python
#
# Copyright 2012-2013 Ettus Research LLC
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

import sys, os, string, tempfile, math
import traceback
import shutil, hashlib, urllib2, zipfile

from optparse import OptionParser

# Chunk size in bytes used when reading files and download streams.
_DEFAULT_BUFFER_SIZE      = 8192
# Path components of the images directory, relative to an install prefix.
_BASE_DIR_STRUCTURE_PARTS = ["share", "uhd", "images"]
# os.path.join() takes the components as separate arguments, so the list must
# be unpacked; passing the list itself does not produce a path string.
_BASE_DIR_STRUCTURE       = os.path.join(*_BASE_DIR_STRUCTURE_PARTS)
# NOTE(review): the @...@ tokens look like placeholders filled in by the build
# system (CMake configure_file style) -- confirm against the build scripts.
_DEFAULT_INSTALL_PATH     = os.path.join("@CMAKE_INSTALL_PREFIX@", _BASE_DIR_STRUCTURE)
_AUTOGEN_IMAGES_SOURCE    = "@UHD_IMAGES_DOWNLOAD_SRC@"
_AUTOGEN_IMAGES_CHECKSUM  = "@UHD_IMAGES_MD5SUM@"
# Default key into the table of supported checksum functions.
_IMAGES_CHECKSUM_TYPE     = "md5"
# Address shown to the user when an error should be reported.
_CONTACT                  = "support@ettus.com"

def md5Checksum(filePath):
    """Return the hexadecimal MD5 digest of the file at filePath.

    The file is opened in binary mode and read in chunks of
    _DEFAULT_BUFFER_SIZE bytes, so it is never held in memory as a whole.

    Any exception raised while opening or reading the file is reported on
    stdout and then re-raised unchanged.
    """
    try:
        digest = hashlib.md5()
        with open(filePath, 'rb') as fh:
            # iter(callable, sentinel) keeps reading until read() returns an
            # empty string, i.e. end of file.
            for chunk in iter(lambda: fh.read(_DEFAULT_BUFFER_SIZE), b''):
                digest.update(chunk)
        return digest.hexdigest()
    except Exception as e:
        print("Failed to calculated MD5 sum of: %s (%s)" % (filePath, e))
        # A bare raise re-raises the active exception with its original traceback.
        raise

# Supported checksum algorithms: lower-case name -> function that takes a
# file path and returns the digest as a string.
_checksum_fns = dict(md5=md5Checksum)

class temporary_directory(object):
    """Context manager that creates a temporary directory and removes it on exit.

    The value bound by "with ... as" is the path of the new directory.
    Exceptions raised inside the with-block are never suppressed.
    """
    def __enter__(self):
        try:
            self.name = tempfile.mkdtemp()
            return self.name
        except Exception as e:
            print("Failed to create a temporary directory (%s)" % (e))
            # A bare raise re-raises the active exception with its original traceback.
            raise

    # Can return 'True' to suppress incoming exception
    def __exit__(self, exc_type, exc_value, exc_traceback):
        try:
            shutil.rmtree(self.name)
        except Exception as e:
            # Best effort: a failed clean-up is reported but must not mask
            # whatever exception the with-block itself raised.
            print("Could not delete temporary directory: %s (%s)" % (self.name, e))
        return False

class uhd_images_downloader(object):
    """Download, verify, unpack and install an images archive.

    The class holds no state; each method performs one step and the caller
    chains them together.
    """
    def __init__(self):
        pass

    def download(self, images_url, filename, buffer_size=_DEFAULT_BUFFER_SIZE, print_progress=False):
        """Download images_url into the local file filename.

        images_url     -- URL to fetch
        filename       -- path of the file to (over)write
        buffer_size    -- number of bytes requested per read
        print_progress -- when True, a progress line is written to stdout

        Returns (reported_size, downloaded_size): the size announced by the
        Content-Length header as a float (0.0 when the header is absent) and
        the number of bytes actually written.
        """
        opener = urllib2.build_opener()
        # The opener sends the headers listed in its 'addheaders' attribute;
        # any other attribute name is silently ignored.
        opener.addheaders = [('User-Agent', 'UHD Images Downloader')]
        u = opener.open(images_url)
        try:
            content_lengths = u.info().getheaders("Content-Length")
            # The header may be missing; treat the total size as unknown (0).
            filesize = float(content_lengths[0]) if content_lengths else 0.

            filesize_dl = 0

            with open(filename, "wb") as f:
                while True:
                    chunk = u.read(buffer_size)
                    if not chunk:
                        break

                    f.write(chunk)
                    filesize_dl += len(chunk)

                    if print_progress:
                        if filesize > 0:
                            percent = int(filesize_dl*100./filesize)
                        else:
                            percent = 0
                        status = r"%05d kB / %05d kB (%03d%%)" % (int(math.ceil(filesize_dl/1000.)), int(math.ceil(filesize/1000.)), percent)
                        # Trailing backspaces move the cursor back so the next
                        # status overwrites this one.
                        status += chr(8)*(len(status)+1)
                        sys.stdout.write(status)
                        sys.stdout.flush()
        finally:
            u.close()

        if print_progress:
            print("")

        return (filesize, filesize_dl)

    def check_directories(self, dirs, print_progress=False):
        """Check that the path dirs can be created or written to.

        Starting at the absolute form of dirs (the current directory when
        dirs is None or empty), walk towards the filesystem root until a
        path that exists is found, then test it.

        Returns (ok, path): ok is True when path is an existing directory
        with write permission; path is the last path examined.
        """
        if dirs is None or dirs == "":
            dirs = "."
        dirs = os.path.abspath(dirs)

        def _check_part(head, tail=None):
            if print_progress:
                print("Checking: %s" % (head))
            if tail == "":
                # os.path.split() returned an empty tail: head cannot be
                # shortened any further and it does not exist.
                return (False, head)
            if not os.path.exists(head):
                if print_progress:
                    print("Does not exist: %s" % (head))
                return _check_part(*os.path.split(head))
            if not os.path.isdir(head):
                if print_progress:
                    print("Is not a directory: %s" % (head))
                return (False, head)
            if not os.access(head, os.W_OK):
                if print_progress:
                    print("Write permission denied on: %s" % (head))
                return (False, head)
            if print_progress:
                print("Write permission granted on: %s" % (head))
            return (True, head)

        return _check_part(dirs)

    def validate_checksum(self, checksum_fn, file_path, expecting, print_progress=False):
        """Compare the checksum of file_path with the expected value.

        checksum_fn    -- callable taking a path and returning a checksum,
                          or None to skip the check
        expecting      -- expected checksum; None or "" accepts any value
        print_progress -- accepted for symmetry with the other methods; unused

        Returns (match, calculated_checksum); calculated_checksum is "" when
        checksum_fn is None.
        """
        if checksum_fn is None:
            return (True, "")

        calculated_checksum = checksum_fn(file_path)

        if (expecting is not None) and (expecting != "") and calculated_checksum != expecting:
            return (False, calculated_checksum)

        return (True, calculated_checksum)

    def extract_images_archive(self, archive_path, destination=None, print_progress=False):
        """Extract the zip file archive_path into a directory next to it.

        The directory is named after the archive without its extension; an
        existing directory of that name is deleted first.

        destination -- currently unused

        Returns the path of the extraction directory. Raises Exception when
        archive_path does not exist or its directory is not writable.
        """
        if not os.path.exists(archive_path):
            if print_progress:
                print("Path does not exist: %s" % (archive_path))
            raise Exception("path does not exist: %s" % (archive_path))

        if print_progress:
            print("Archive path: %s" % (archive_path))

        (head, tail) = os.path.split(archive_path)
        if head == "":
            # A bare file name lives in the current directory.
            head = os.curdir

        if not os.access(head, os.W_OK):
            if print_progress:
                print("Write access denied on: %s" % (head))
            raise Exception("write access denied on: %s" % (head))

        (root, ext) = os.path.splitext(tail)
        temp_dir = os.path.join(head, root)

        if print_progress:
            print("Temporary extraction location: %s" % (temp_dir))

        if os.path.exists(temp_dir):
            if print_progress:
                print("Deleting existing location: %s" % (temp_dir))
            shutil.rmtree(temp_dir)

        if print_progress:
            print("Creating directory: %s" % (temp_dir))
        os.mkdir(temp_dir)

        if print_progress:
            print("Extracting archive %s to %s" % (archive_path, temp_dir))

        # NOTE(review): extractall() trusts the member names stored in the
        # archive -- only use this on archives whose origin has been verified.
        images_zip = zipfile.ZipFile(archive_path)
        try:
            images_zip.extractall(temp_dir)
        finally:
            images_zip.close()

        return temp_dir

    def install_images(self, source, dest, keep=False, print_progress=False):
        """Copy the images tree found under source to dest.

        The files are taken from <source>/<basename of source>/ followed by
        the components in _BASE_DIR_STRUCTURE_PARTS.

        keep -- when False an existing dest is deleted first and the tree is
                copied afresh; when True dest is kept and files with the same
                name are overwritten

        Returns None without doing anything when source does not exist.
        Raises Exception when source exists but does not contain the images
        directory; dest is left untouched in that case.
        """
        if not os.path.exists(source):
            if print_progress:
                print("Source path does not exist: %s" % (source))
            return

        (head, tail) = os.path.split(source)

        if print_progress:
            print("Source install path: %s" % (source))

        uhd_source = os.path.join(source, tail, *_BASE_DIR_STRUCTURE_PARTS)

        # Verify there is something to install BEFORE touching dest, so a
        # malformed archive cannot wipe an existing installation.
        if not os.path.isdir(uhd_source):
            raise Exception("images directory not found: %s" % (uhd_source))

        if keep:
            if print_progress:
                print("Not wiping directory tree (existing files will be overwritten): %s" % (dest))
        elif os.path.exists(dest):
            if print_progress:
                print("Deleting directory tree: %s" % (dest))
            shutil.rmtree(dest)

        if print_progress:
            print("Copying files from: %s" % (uhd_source))
            print("Copying files to:   %s" % (dest))

        if keep:
            # mgrant @ http://stackoverflow.com/questions/12683834/how-to-copy-directory-recursively-in-python-and-overwrite-all
            def _recursive_overwrite(src, dest, ignore=None):
                if os.path.isdir(src):
                    if not os.path.isdir(dest):
                        os.makedirs(dest)
                    files = os.listdir(src)
                    if ignore is not None:
                        ignored = ignore(src, files)
                    else:
                        ignored = set()
                    for f in files:
                        if f not in ignored:
                            _recursive_overwrite(os.path.join(src, f), os.path.join(dest, f), ignore)
                else:
                    shutil.copyfile(src, dest)

            _recursive_overwrite(uhd_source, dest)
        else:
            shutil.copytree(uhd_source, dest)

def main():
    """Command-line entry point: download, verify and install the images.

    The default install location is the UHD_IMAGES_DIR environment variable
    when it is set and non-empty, otherwise _DEFAULT_INSTALL_PATH.

    Returns the process exit status: 0 when the images were installed,
    1 when the options were invalid, the destination is not writable, the
    checksum did not match, installation failed, the user cancelled, or an
    unexpected exception occurred.
    """
    env_images_dir = os.environ.get("UHD_IMAGES_DIR")
    if env_images_dir:
        default_images_dir = env_images_dir
        print("UHD_IMAGES_DIR environment variable is set. Default install location: %s" % default_images_dir)
    else:
        default_images_dir = _DEFAULT_INSTALL_PATH

    parser = OptionParser()
    parser.add_option("-i", "--install-location",   type="string",          default=default_images_dir,
                        help="Set custom install location for images [default=%default]")
    parser.add_option("--buffer-size",              type="int",             default=_DEFAULT_BUFFER_SIZE,
                        help="Set download buffer size [default=%default]")
    parser.add_option("-u", "--url",                type="string",          default=_AUTOGEN_IMAGES_SOURCE,
                        help="Set images download location [default=%default]")
    parser.add_option("-c", "--checksum",           type="string",          default=_AUTOGEN_IMAGES_CHECKSUM,
                        help="Validate images archive against this checksum (blank to skip) [default=%default]")
    parser.add_option("-t", "--checksum-type",      type="string",          default=_IMAGES_CHECKSUM_TYPE,
                        help=("Select checksum hash function (options: %s) [default=%%default]" % (",".join(_checksum_fns))))
    parser.add_option("-k", "--keep",               action="store_true",    default=False,
                        help="Do not clear images directory before extracting new files [default=%default]")
    parser.add_option("-v", "--verbose",            action="store_true",    default=False,
                        help="Enable verbose output [default=%default]")

    (options, args) = parser.parse_args()

    if options.buffer_size <= 0:
        print("Invalid buffer size: %s" % (options.buffer_size))
        return 1

    # A blank checksum disables validation, so the type is not looked up.
    checksum_fn = None
    if options.checksum != "":
        options.checksum_type = options.checksum_type.lower()
        if options.checksum_type not in _checksum_fns:
            print("Not a supported checksum function: %s" % (options.checksum_type))
            return 1
        checksum_fn = _checksum_fns[options.checksum_type]

    # The archive file name is the last path component of the URL.
    url_parts = options.url.split("/")
    if len(url_parts) <= 1 or url_parts[-1] == "":
        print("Not a valid URL: %s" % (options.url))
        return 1
    images_filename = url_parts[-1]

    images_dir = os.path.abspath(options.install_location)  # This will use the current working directory if it's not absolute

    if options.verbose:
        print("Requested install location: %s" % (options.install_location))
        print("Images source:              %s" % (options.url))
        print("Images filename:            %s" % (images_filename))
        # Report the checksum type actually selected, not the built-in default.
        print("Images checksum:            %s (%s)" % (options.checksum, options.checksum_type))
        print("Final install location:     %s" % (images_dir))
    else:
        print("Images destination:      %s" % (images_dir))

    downloader = uhd_images_downloader()

    try:
        (access, last_path) = downloader.check_directories(images_dir, print_progress=options.verbose)
        if not access:
            print("You do not have sufficient permissions to write to: %s" % (last_path))
            print("Are you root?")
            return 1

        # Returning from inside the with-block still removes the temporary directory.
        with temporary_directory() as temp_dir:
            if options.verbose:
                print("Using temporary directory: %s" % (temp_dir))

            print("Downloading images from: %s" % options.url)

            temp_images_dest = os.path.join(temp_dir, images_filename)

            print("Downloading images to:   %s" % (temp_images_dest))

            (reported_size, downloaded_size) = downloader.download(images_url=options.url, filename=temp_images_dest, buffer_size=options.buffer_size, print_progress=True)

            if options.verbose:
                print("Downloaded %d of %d bytes" % (downloaded_size, reported_size))

            (checksum_match, calculated_checksum) = downloader.validate_checksum(checksum_fn, temp_images_dest, options.checksum, print_progress=options.verbose)

            if options.verbose:
                print("Calculated checksum: %s" % (calculated_checksum))

            if not checksum_match:
                print("Checksum of downloaded file is not correct (not installing - see options to override)")
                print("Expected:   %s" % (options.checksum))
                print("Calculated: %s" % (calculated_checksum))
                print("Please try downloading again.")
                print("If the problem persists, please email the output to: %s" % (_CONTACT))
                return 1

            if options.verbose:
                if options.checksum == "":
                    print("Ignoring checksum")
                else:
                    print("Checksum OK")

            try:
                extract_path = downloader.extract_images_archive(temp_images_dest, print_progress=options.verbose)

                if options.verbose:
                    print("Image archive extracted to: %s" % (extract_path))

                downloader.install_images(extract_path, images_dir, options.keep, print_progress=options.verbose)

                print("")
                print("Images successfully installed to: %s" % (images_dir))
            except Exception as e:
                print("Failed to install image archive: %s" % (e))
                print("This is usually a permissions problem.")
                print("Please check your file system access rights and try again.")

                if options.verbose:
                    traceback.print_exc()
                else:
                    print("You can run this again with the '--verbose' flag to see more information")
                print("If the problem persists, please email the output to: %s" % (_CONTACT))
                return 1
    except KeyboardInterrupt:
        print("")
        print("Cancelled at user request")
        return 1
    except Exception as e:
        print("Downloader raised an unhandled exception: %s" % (e))
        if options.verbose:
            traceback.print_exc()
        else:
            print("You can run this again with the '--verbose' flag to see more information")
        print("If the problem persists, please email the output to: %s" % (_CONTACT))
        return 1

    return 0

if __name__ == "__main__":
    # Run only when executed as a script; main()'s return value becomes the
    # process exit status.
    exit_status = main()
    sys.exit(exit_status)