#!/usr/bin/env python
#
# Copyright 2009 Free Software Foundation, Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

import termios
import tty
import os
import sys
import threading
import Queue
from optparse import OptionParser
import time

import sbf

# Escape character (0x7d, '}') inserted before special bytes in a packet payload.
GDB_ESCAPE = '\x7d'

# Positions of the fields in the attribute list used by termios
# tcgetattr()/tcsetattr().
(IFLAG, OFLAG, CFLAG, LFLAG, ISPEED, OSPEED, CC) = range(7)

class terminal(object):
    """A serial tty opened in raw mode at a fixed speed.

    The termios attributes the device had when it was opened are saved and
    restored when the object is finalized.
    """

    def __init__(self, device, speed_bits):
        """Open device and put it into raw mode at the requested speed.

        device is the path of the tty to open.  speed_bits is a termios
        speed constant (e.g. termios.B115200) used for both input and
        output.

        Raises ValueError if device is not a tty.
        """
        fd = os.open(device, os.O_RDWR)
        if not os.isatty(fd):
            os.close(fd)                # don't leak the descriptor
            raise ValueError(device + " is not a tty")

        # Separate unbuffered file objects for each direction; the write
        # side uses a dup'ed descriptor so each can be closed on its own.
        self.read_file = os.fdopen(fd, "rb", 0)
        self.write_file = os.fdopen(os.dup(fd), "wb", 0)
        tty_fd = self.write_file.fileno()

        self.old_attrs = termios.tcgetattr(tty_fd)
        attrs = list(self.old_attrs)    # copy of attributes
        attrs[ISPEED] = speed_bits      # set input and output speed
        attrs[OSPEED] = speed_bits
        termios.tcsetattr(tty_fd, termios.TCSAFLUSH, attrs)
        tty.setraw(tty_fd)              # enable raw mode

        # Re-read the attributes after setraw() so its changes are kept.
        attrs = termios.tcgetattr(tty_fd)
        attrs[CC][termios.VMIN] = 1     # minimum of 1 char
        attrs[CC][termios.VTIME] = 1    # wait no longer than 1/10
        termios.tcsetattr(tty_fd, termios.TCSAFLUSH, attrs)

    def __del__(self):
        """Restore the saved tty attributes and close both file objects.

        Tolerates a partially constructed instance (for example when
        __init__ raised before the files were opened).
        """
        read_file = getattr(self, "read_file", None)
        write_file = getattr(self, "write_file", None)
        old_attrs = getattr(self, "old_attrs", None)
        try:
            if write_file is not None and old_attrs is not None:
                termios.tcsetattr(write_file.fileno(), termios.TCSAFLUSH,
                                  old_attrs)
        finally:
            # Close the files even if restoring the attributes failed.
            if read_file is not None:
                read_file.close()
            if write_file is not None:
                write_file.close()

    def read(self, n):
        """Read at most n bytes from tty"""
        return self.read_file.read(n)

    def write(self, str):
        """Write str to tty."""
        return self.write_file.write(str)


def hexnibble(i):
    """Return the lowercase hex digit for the low four bits of i."""
    low_bits = i & 0xf
    return "%x" % low_bits

def build_pkt(payload):
    """Frame payload as '$<payload>#<checksum>'.

    Each '$', '#' or GDB_ESCAPE character in the payload is preceded by
    GDB_ESCAPE.  The checksum is the sum of the payload characters (the
    inserted escapes are not counted) modulo 256, written as two lowercase
    hex digits.
    """
    body = []
    total = 0

    for ch in payload:
        total += ord(ch)
        if ch == '$' or ch == '#' or ch == GDB_ESCAPE:
            body.append(GDB_ESCAPE)
        body.append(ch)

    # Two hex digits, high nibble first.
    trailer = '#' + ('%02x' % (total & 0xff))
    return '$' + ''.join(body) + trailer


def build_memory_read_pkt(addr, len):
    """Build an 'm' packet whose payload is 'm<addr>,<len>' with both in hex."""
    payload = 'm%x,%x' % (addr, len)
    return build_pkt(payload)

def build_memory_write_hex_pkt(addr, s):
    """Build an 'M' packet: 'M<addr>,<len(s)>:' followed by s as hex digits.

    Every character of s is encoded as two lowercase hex digits.
    """
    header = 'M%x,%x:' % (addr, len(s))
    digits = []
    for byte in s:
        digits.append("%02x" % ord(byte))
    return build_pkt(header + ''.join(digits))

def build_memory_write_pkt(addr, s):
    """Build an 'X' packet: 'X<addr>,<len(s)>:' followed by s unencoded."""
    header = 'X%x,%x:' % (addr, len(s))
    return build_pkt(header + s)

def build_goto_pkt(addr):
    """Build a 'c' packet whose payload is 'c<addr>' with addr in hex."""
    payload = 'c%x' % addr
    return build_pkt(payload)


def get_packet(f):
    """Return the payload of the next valid packet, or None on EOF or timeout.

    Bytes are read one at a time from f's file descriptor and every byte
    read is echoed to stdout.  A packet is framed as '$<payload>#<c1><c2>'
    where <c1><c2> is the two-hex-digit sum of the unescaped payload
    characters modulo 256.  A checksum of '..' is accepted without
    verification.  Inside the payload, GDB_ESCAPE causes the following
    byte to be taken literally.

    Packets with a bad or malformed checksum are dropped and scanning
    resumes with the next '$'.
    """
    LOOKING_FOR_DOLLAR = 0
    LOOKING_FOR_HASH = 1
    CSUM1 = 2
    CSUM2 = 3

    fd = f.fileno()

    def getc():
        # Read a single byte straight from the descriptor and echo it.
        c = os.read(fd, 1)
        sys.stdout.write(c)
        return c

    state = LOOKING_FOR_DOLLAR
    buf = []
    chksum1 = None

    while True:
        ch = getc()
        if len(ch) == 0:
            print("Returning None")
            return None

        if state == LOOKING_FOR_DOLLAR:
            # Everything before the start-of-packet marker is ignored.
            if ch == '$':
                buf = []
                state = LOOKING_FOR_HASH

        elif state == LOOKING_FOR_HASH:
            if ch == '$':
                # Unescaped '$' inside a packet: abandon it and resync.
                state = LOOKING_FOR_DOLLAR
            elif ch == '#':
                state = CSUM1
            else:
                if ch == GDB_ESCAPE:
                    # The byte after the escape is payload, whatever it is.
                    ch = getc()
                    if len(ch) == 0:
                        print("Returning None")
                        return None
                buf.append(ch)

        elif state == CSUM1:
            chksum1 = ch
            state = CSUM2

        elif state == CSUM2:
            chksum2 = ch
            r = ''.join(buf)
            if chksum1 == '.' and chksum2 == '.':
                return r

            try:
                expected_checksum = int(chksum1 + chksum2, 16)
            except ValueError:
                # Checksum characters are not hex: treat as a corrupt
                # packet instead of letting the exception escape.
                state = LOOKING_FOR_DOLLAR
                continue

            checksum = 0
            for c in buf:
                checksum += ord(c)

            checksum &= 0xff
            if checksum == expected_checksum:
                return r

            state = LOOKING_FOR_DOLLAR

        else:
            raise ValueError( "Invalid state")


class packet_reader_thread(threading.Thread):
    """Daemon thread that reads packets from a tty and queues them.

    Each packet returned by get_packet() is put on the queue as the tuple
    ('pkt', payload).  The thread starts itself on construction.
    """

    def __init__(self, tty_in, q):
        """Start reading packets from tty_in and posting them to q."""
        threading.Thread.__init__(self)
        self.setDaemon(True)
        self.tty_in = tty_in
        self.q = q
        self._keep_running = True
        self.start()

    def run(self):
        """Forward packets to the queue until _keep_running is cleared.

        The flag is only checked between packets, so the loop ends after
        the get_packet() call in progress returns.
        """
        while self._keep_running:
            p = get_packet(self.tty_in)
            if p is not None:
                self.q.put(('pkt', p))


def _make_tr_table():
    table = []
    for c in range(256):
        if c < ord(' ') or c > ord('~'):
            table.append('.')
        else:
            table.append(chr(c))
    return ''.join(table)
        

class controller(object):
    """Upload code over a tty, one acknowledged piece at a time.

    Events arrive on a single queue: ('pkt', payload) from the packet
    reader thread and ('timeout', tid) from timer threads.
    """

    def __init__(self, tty):
        """Create a controller for tty and start its packet reader thread.

        tty must provide a write() method and a read_file attribute.
        """
        self.tty = tty
        self.q = Queue.Queue(0)
        self.timers = {}                # tid -> pending threading.Timer
        self.next_tid = 1
        self.current_tid = 0            # tid of the timeout we care about
        self.ntimeouts = 0
        self.packet_reader = packet_reader_thread(tty.read_file, self.q)
        self.state = None
        self.debug = False
        self.tt = _make_tr_table()

        self.done = False
        self.addr = None
        self.bits = None

    def shutdown(self):
        """Ask the packet reader thread to stop."""
        self.packet_reader._keep_running = False

    def start_timeout(self, timeout_in_secs):
        """Start a timer and return its id.

        When the timer expires, ('timeout', tid) is put on the event queue.
        """
        def callback(tid):
            # Runs in the timer thread.  pop() avoids a check-then-delete
            # race with cancel_timeout() in the main thread.
            self.timers.pop(tid, None)
            self.q.put(('timeout', tid))
        self.next_tid += 1
        tid = self.next_tid
        timer = threading.Timer(timeout_in_secs, callback, (tid,))
        self.timers[tid] = timer
        timer.start()
        return tid

    def cancel_timeout(self, tid):
        """Cancel the timer with id tid; unknown or expired ids are ignored."""
        timer = self.timers.pop(tid, None)
        if timer is not None:
            timer.cancel()

    def send_packet(self, pkt):
        """Write pkt to the tty, logging a printable preview if debugging."""
        if self.debug:
            if len(pkt) > 64:
                s = pkt[0:64] + '...'
            else:
                s = pkt
            sys.stdout.write('-> ' +  s.translate(self.tt) + '\n')
        self.tty.write(pkt)

    def send_packet_start_timeout(self, pkt, secs):
        """Send pkt and start the timeout that guards its acknowledgement."""
        self.send_packet(pkt)
        self.current_tid = self.start_timeout(secs)

    def upload_code(self, sbf):
        """Upload the image described by sbf and jump to its entry point.

        sbf must provide iterator(max_piece), yielding (addr, bits) pairs,
        and an entry attribute.  Each piece is sent as a memory write
        packet and must be answered with 'OK' before the next is sent.
        A piece that is not answered in time is resent.  After the last
        piece is acknowledged a goto packet for the entry address is sent.

        Returns True on success (or if there is nothing to send).  Returns
        False after 5 timeouts on the same piece or when the firmware
        replies with anything other than 'OK'.
        """
        MAX_PIECE = 512                 # biggest piece to send
        MWRITE_TIMEOUT = 0.1

        # Values stored in self.state.  Only WAIT_FOR_ACK and DONE are
        # currently used.
        IDLE = 0
        WAIT_FOR_ACK = 1
        UPLOAD_DONE = 2
        DONE = 3
        FAILED = 4

        self.done = False
        self.ntimeouts = 0              # don't inherit counts from a prior upload
        it = sbf.iterator(MAX_PIECE)
        entry_addr = sbf.entry

        def get_next_bits():
            try:
                (self.addr, self.bits) = it.next()
            except StopIteration:
                self.done = True

        def is_done():
            return self.done

        def send_piece():
            pkt = build_memory_write_pkt(self.addr, self.bits)
            self.send_packet_start_timeout(pkt, MWRITE_TIMEOUT)
            # Must be stored on self: a stale DONE left over from an
            # earlier upload would otherwise end this one after its
            # first acknowledgement.
            self.state = WAIT_FOR_ACK

        def advance():
            get_next_bits()
            if is_done():
                self.state = DONE
                self.send_packet(build_goto_pkt(entry_addr))

            else:
                self.ntimeouts = 0
                send_piece()

        get_next_bits()
        if is_done():                   # empty file
            return True

        send_piece()                    # initial transition

        while 1:
            (event, value) = self.q.get()

            if event == 'timeout' and value == self.current_tid:
                self.ntimeouts += 1
                if self.ntimeouts >= 5:
                    return False        # say we failed
                send_piece()            # resend

            elif event == 'timeout':
                # Timer that fired just before it was cancelled or
                # superseded; nothing to do.
                pass

            elif event == 'pkt':
                if value == 'OK':
                    self.cancel_timeout(self.current_tid)
                    advance()
                    if self.state == DONE:
                        return True
                else:
                    # Don't leave the ack timer running after giving up.
                    self.cancel_timeout(self.current_tid)
                    print("Error returned from firmware: " + value)
                    return False

            else:
                print("Unknown event:", (event, value))

def main():
    """Command-line entry point: upload one file over a serial port.

    The file named on the command line is read as an SBF file.  If that
    fails, it is treated as a raw binary loaded and entered at 0x8000;
    ELF files are rejected.  The image is uploaded through the tty given
    by -t/--tty at 115200 baud.

    Exits with status 1 on a usage error, an ELF input file, or a failed
    upload.
    """
    usage="%prog: [options] filename"
    parser = OptionParser(usage=usage)
    parser.add_option("-t", "--tty", type="string", default="/dev/ttyS0",
                      help="select serial port [default=%default]")

    (options, args) = parser.parse_args()
    if len(args) != 1:
        parser.print_help()
        raise SystemExit(1)

    filename = args[0]
    f = open(filename, "rb")
    try:
        try:
            # Try to open the file as an SBF file
            sbf_header = sbf.read_sbf(f)
        except Exception:
            # If that fails, build an SBF from the binary, assuming default
            # load address and entry point.  (Exception rather than a bare
            # except, so Ctrl-C and SystemExit are not swallowed.)
            f.seek(0)
            t = f.read()
            if t.startswith('\177ELF'):
                sys.stderr.write("Can't load an ELF file.  Please use an SBF file instead.\n")
                raise SystemExit(1)
            sbf_header = sbf.header(0x8000, [sbf.sec_desc(0x8000, t)])

        tty = terminal(options.tty, termios.B115200)
        ctrl = controller(tty)
        ok = ctrl.upload_code(sbf_header)
    finally:
        # Kept open until the upload is over in case the SBF object reads
        # from the file lazily.
        f.close()

    if ok:
        print("OK")
        try:
            raw_input("Press Enter to exit: ")
        except KeyboardInterrupt:
            pass
        ctrl.shutdown()
        time.sleep(0.2)

    else:
        print("Upload failed")
        ctrl.shutdown()
        raise SystemExit(1)             # report the failure to the caller
    
    
    
# Run the uploader only when this file is executed as a script.
if __name__ == "__main__":
    main()