#!/usr/bin/env python3
#
# Copyright 2017 Ettus Research, a National Instruments Company
#
# SPDX-License-Identifier: GPL-3.0-or-later
#
"""
RPC shell to debug USRP MPM capable devices
"""

from __future__ import print_function
import cmd
import time
import argparse
import multiprocessing
from importlib import import_module

# RPC port number this script was written against. It is used as the fallback
# when usrp_mpm cannot be imported on the machine running the shell.
DEFAULT_MPM_RPC_PORT = 49601

try:
    from usrp_mpm.mpmtypes import MPM_RPC_PORT
except ImportError:
    # usrp_mpm is not importable here; resolve to the built-in default below.
    MPM_RPC_PORT = None

if MPM_RPC_PORT is None:
    MPM_RPC_PORT = DEFAULT_MPM_RPC_PORT
elif MPM_RPC_PORT != DEFAULT_MPM_RPC_PORT:
    # The imported value wins, but tell the user the two have diverged.
    print("Warning: Default encoded MPM RPC port does not match that in MPM.")


def parse_args():
    """
    Build the command line parser for the shell and parse sys.argv.

    Recognized arguments:
    host         -- positional; host to connect to
    -p/--port    -- integer port, defaults to MPM_RPC_PORT
    -c/--claim   -- flag; claim the device after connecting
    -j/--hijack  -- token string of a running session to hijack

    Returns the argparse namespace with attributes host, port, claim and
    hijack.
    """
    # (names, keyword options) in the order they are registered on the parser
    argument_table = (
        (('host',), {
            'help': "Specify host to connect to.",
            'default': None,
        }),
        (('-p', '--port'), {
            'type': int,
            'help': "Specify port to connect to.",
            'default': MPM_RPC_PORT,
        }),
        (('-c', '--claim'), {
            'action': 'store_true',
            'help': "Claim device after connecting.",
        }),
        (('-j', '--hijack'), {
            'type': str,
            'help': "Hijack running session (excludes --claim).",
        }),
    )
    arg_parser = argparse.ArgumentParser(description="MPM Shell")
    for names, options in argument_table:
        arg_parser.add_argument(*names, **options)
    return arg_parser.parse_args()


def split_args(args, *default_args):
    """
    Split a whitespace-separated argument string into a list of tokens,
    filling missing trailing positions from default_args.

    args         -- string typed by the user (None is treated as empty)
    default_args -- positional defaults; default_args[i] is used when the
                    user supplied fewer than i+1 tokens

    Returns a list. User-supplied entries are strings; defaults are returned
    unchanged (no type conversion). If the user supplies more tokens than
    there are defaults, all tokens are returned.
    """
    tokens = (args or "").split()
    # Only the defaults at positions the user did not fill in are appended.
    return tokens + list(default_args[len(tokens):])


class MPMClaimer(object):
    """
    Holds a claim on a device.

    A separate process (see claim_loop()) owns its own RPC connection and is
    driven through two queues: commands go in through _cmd_q, tokens come
    back through _token_q.
    """
    def __init__(self, host, port):
        """
        Create the queues and start the claimer process for host:port.
        """
        self.token = None
        self.hijacked = False
        self._cmd_q = multiprocessing.Queue()
        self._token_q = multiprocessing.Queue()
        loop_args = (host, port, self._cmd_q, self._token_q)
        self._claim_loop = multiprocessing.Process(
            target=self.claim_loop,
            name="Claimer Loop",
            args=loop_args,
        )
        self._claim_loop.start()

    def claim_loop(self, host, port, cmd_q, token_q):
        """
        Body of the claimer process.

        Once per iteration (with a one second sleep in between) this either
        reclaims the held token, or handles one command read from cmd_q:
        'claim', 'unclaim' or 'exit'. Every handled command posts the
        resulting token (or None) to token_q. An RPCError ends the loop.
        """
        from mprpc import RPCClient
        from mprpc.exceptions import RPCError
        rpc = RPCClient(host, port, pack_params={'use_bin_type': True})
        held_token = None
        request = None
        running = True
        try:
            while running:
                if held_token and not request:
                    # Nothing requested: just refresh the existing claim.
                    rpc.call('reclaim', held_token)
                elif request == 'claim':
                    if held_token:
                        print("Already have claim")
                    else:
                        held_token = rpc.call('claim', 'MPM Shell')
                    token_q.put(held_token)
                elif request in ('unclaim', 'exit'):
                    # Both commands release the claim; only 'exit' also
                    # terminates the loop.
                    if held_token:
                        rpc.call('unclaim', held_token)
                    held_token = None
                    token_q.put(None)
                    running = request != 'exit'
                time.sleep(1)
                request = None if cmd_q.empty() else cmd_q.get(False)
        except RPCError as ex:
            print("Unexpected RPC error in claimer loop!")
            print(str(ex))

    def exit(self):
        """
        Release the claim, tell the claimer process to stop and wait for it.
        """
        self.unclaim()
        self._cmd_q.put('exit')
        self._claim_loop.join()

    def unclaim(self):
        """
        Drop the current token.

        A hijacked token is only forgotten locally; otherwise an 'unclaim'
        command is queued for the claimer process.
        """
        if self.hijacked:
            self.hijacked = False
        else:
            self._cmd_q.put('unclaim')
        self.token = None

    def claim(self):
        """
        Queue a 'claim' command and wait up to five seconds for the token.

        The blocking queue read raises queue.Empty if nothing arrives in
        time.
        """
        self._cmd_q.put('claim')
        self.token = self._token_q.get(block=True, timeout=5.0)

    def get_token(self):
        """
        Return the current token (or None), first picking up a newer value
        from the token queue if one is waiting.
        """
        if self._token_q.empty():
            return self.token
        self.token = self._token_q.get(False)
        return self.token

    def hijack(self, token):
        """
        Adopt a token that was claimed by another session.

        Does nothing except print a notice if a token is already held.
        """
        if self.token:
            print("Already have token")
            return
        self.token = token
        self.hijacked = True

class MPMShell(cmd.Cmd):
    """
    RPC Shell class. See cmd module.

    On connect, every method reported by the server's 'list_methods' call is
    attached to this instance as a do_<name> command, so it can be invoked
    (and tab-completed) from the prompt.
    """
    def __init__(self, host, port, claim, hijack):
        """
        Set up the shell and, if a host is given, connect right away.

        host   -- host to connect to, or None to start disconnected
        port   -- RPC port
        claim  -- if true, claim the device after connecting
        hijack -- token of a running session to hijack; only considered when
                  claim is false
        """
        cmd.Cmd.__init__(self)
        self.prompt = "> "
        self.client = None
        self.remote_methods = []
        self._host = host
        self._port = port
        self._device_info = None
        self._claimer = MPMClaimer(self._host, self._port)
        if host is not None:
            self.connect(host, port)
            if claim:
                self.claim()
            elif hijack:
                self.hijack(hijack)
        self.update_prompt()

    def _add_command(self, command, docs, requires_token=False):
        """
        Add a command to the current session

        command        -- name of the remote method
        docs           -- help text, stored as the new command's docstring
        requires_token -- if true, the claim token is passed as the first
                          RPC argument

        A command whose do_<name> attribute already exists is left alone.
        """
        cmd_name = 'do_' + command
        if not hasattr(self, cmd_name):
            new_command = lambda args: self.rpc_template(
                str(command), requires_token, args
            )
            new_command.__doc__ = docs
            setattr(self, cmd_name, new_command)
            self.remote_methods.append(command)

    def rpc_template(self, command, requires_token, args=None):
        """
        Template function to create new RPC shell commands

        command        -- remote method name
        requires_token -- if true, a claim token must be available; it is
                          inserted as the first call argument
        args           -- argument string as typed by the user (see
                          expand_args())

        Prints the outcome and returns the RPC response, or None if the call
        was not made or failed.
        """
        if requires_token and \
                (self._claimer is None or self._claimer.get_token() is None):
            print("Cannot execute `{}' -- no claim available!".format(command))
            return
        # Imported only once an RPC call is actually going to be made.
        from mprpc.exceptions import RPCError
        try:
            if args or requires_token:
                expanded_args = self.expand_args(args)
                if requires_token:
                    expanded_args.insert(0, self._claimer.get_token())
                response = self.client.call(command, *expanded_args)
            else:
                response = self.client.call(command)
        except RPCError as ex:
            print("RPC Command failed!")
            print("Error: {}".format(ex))
            return
        except Exception as ex:
            print("Unexpected exception!")
            print("Error: {}".format(ex))
            return
        if isinstance(response, bool):
            if response:
                print("Command executed successfully!")
            else:
                print("Command failed!")
        else:
            print("==> " + str(response))
        return response

    def get_names(self):
        " We need this for tab completion. "
        # dir(self) also lists the do_* attributes set on the instance by
        # _add_command().
        return dir(self)

    ###########################################################################
    # Cmd module specific
    ###########################################################################
    def run(self):
        " Go, go, go! Runs the command loop; Ctrl-C disconnects and exits. "
        try:
            self.cmdloop()
        except KeyboardInterrupt:
            self.do_disconnect(None)
            exit(0)

    def postcmd(self, stop, line):
        """
        Is run after every command executes. Does:
        - Update prompt

        Always returns None, so the value returned by a command (e.g. a
        truthy RPC response) never terminates the command loop.
        """
        self.update_prompt()

    ###########################################################################
    # Internal methods
    ###########################################################################
    def connect(self, host, port):
        """
        Launch a connection.

        Creates the RPC client, registers the remote methods as shell
        commands and queries the device info.

        Returns False if the client could not be created, True otherwise.
        Errors raised by the 'list_methods' and 'get_device_info' calls are
        not caught here.
        """
        from mprpc import RPCClient
        print("Attempting to connect to {host}:{port}...".format(
            host=host, port=port
        ))
        try:
            self.client = RPCClient(host, port, pack_params={'use_bin_type': True})
            print("Connection successful.")
        except Exception as ex:
            print("Connection refused")
            print("Error: {}".format(ex))
            return False
        self._host = host
        self._port = port
        print("Getting methods...")
        methods = self.client.call('list_methods')
        for method in methods:
            self._add_command(*method)
        print("Added {} methods.".format(len(methods)))
        print("Quering device info...")
        self._device_info = self.client.call('get_device_info')
        return True

    def disconnect(self):
        """
        Clean up after a connection was closed.

        Stops the claimer, closes the client, removes all commands that were
        added by connect() and resets the connection state.
        """
        from mprpc.exceptions import RPCError
        self._device_info = None
        if self._claimer is not None:
            self._claimer.exit()
        if self.client:
            try:
                self.client.close()
            except RPCError as ex:
                print("Error while closing the connection")
                print("Error: {}".format(ex))
        for method in self.remote_methods:
            delattr(self, "do_" + method)
        self.remote_methods = []
        self.client = None
        self._host = None
        self._port = None

    def claim(self):
        " Initialize claim. Always returns True. "
        print("Claiming device...")
        self._claimer.claim()
        return True

    def hijack(self, token):
        """
        Hijack running session

        Returns False if the claimer is already in hijacked state, True
        otherwise.
        """
        if self._claimer.hijacked:
            print("Claimer already active. Can't hijack.")
            return False
        print("Hijacking device...")
        self._claimer.hijack(token)
        return True

    def unclaim(self):
        """
        Release the current claim (or forget a hijacked token).
        """
        self._claimer.unclaim()

    def update_prompt(self):
        """
        Update prompt

        Disconnected: '> '. Connected: the device's 'name' (falling back to
        'serial', then '?'), followed by ' [C]' when a claim is held or
        ' [H]' when the token was hijacked.
        """
        if self._device_info is None:
            self.prompt = '> '
        else:
            token = self._claimer.get_token()
            if token is None:
                claim_status = ''
            elif self._claimer.hijacked:
                claim_status = ' [H]'
            else:
                claim_status = ' [C]'
            self.prompt = '{dev_id}{claim_status}> '.format(
                dev_id=self._device_info.get(
                    'name', self._device_info.get('serial', '?')
                ),
                claim_status=claim_status,
            )

    def expand_args(self, args):
        """
        Takes a string and returns a list

        - '$T' is replaced by the current claim token, if there is one.
        - If the string starts with '=', the remainder is evaluated as a
          Python expression; a non-list result is wrapped in a list.
        - Otherwise the string is split on whitespace and every item is
          converted to int (any base prefix), then float, else kept as str.

        None is treated like an empty string.
        """
        if args is None:
            args = ''
        if self._claimer is not None and self._claimer.get_token() is not None:
            args = args.replace('$T', str(self._claimer.get_token()))
        eval_preamble = '='
        args = args.strip()
        if args.startswith(eval_preamble):
            # NOTE(review): eval() runs whatever the shell user typed. That
            # is the purpose of the '=' syntax in this debug tool, but it
            # must never be fed input from an untrusted source.
            parsed_args = eval(args.lstrip(eval_preamble))
            if not isinstance(parsed_args, list):
                parsed_args = [parsed_args]
        else:
            parsed_args = []
            for arg in args.split():
                try:
                    parsed_args.append(int(arg, 0))
                    continue
                except ValueError:
                    pass
                try:
                    parsed_args.append(float(arg))
                    continue
                except ValueError:
                    pass
                parsed_args.append(arg)
        return parsed_args

    ###########################################################################
    # Predefined commands
    ###########################################################################
    def do_connect(self, args):
        """
        Connect to a remote MPM server. Usage: connect HOST PORT
        """
        try:
            host, port = split_args(args, 'localhost', MPM_RPC_PORT)
            port = int(port)
        except ValueError:
            # Wrong number of arguments or a non-numeric port: report it
            # instead of letting the exception end the whole shell.
            print("Invalid arguments. Usage: connect HOST PORT")
            return
        self.connect(host, port)

    def do_claim(self, _):
        """
        Claim the device
        """
        self.claim()

    def do_hijack(self, token):
        """
        Hijack a running session. Usage: hijack TOKEN
        """
        self.hijack(token)

    def do_unclaim(self, _):
        """
        Release the claim on the device
        """
        self.unclaim()

    def do_disconnect(self, _):
        """
        disconnect from the RPC server
        """
        self.disconnect()

    def do_import(self, args):
        """import a python module into the global namespace"""
        # The module becomes visible to '=' expressions in expand_args().
        globals()[args] = import_module(args)

    def do_EOF(self, _):
        " When catching EOF, exit the program. "
        print("Exiting...")
        self.disconnect()
        exit(0)

def main():
    """
    Go, go, go! Parse the command line, create the shell and run it.

    Returns True when the shell ended normally or was interrupted with
    Ctrl-C, and False when an uncaught exception ended the command loop, so
    that the caller can turn the result into a process exit status.
    """
    args = parse_args()
    my_shell = MPMShell(args.host, args.port, args.claim, args.hijack)

    try:
        # run() has no meaningful return value; success is reported below.
        my_shell.run()
    except KeyboardInterrupt:
        my_shell.disconnect()
    except Exception as ex:
        print("Uncaught exception: " + str(ex))
        my_shell.disconnect()
        return False
    return True

if __name__ == "__main__":
    # The result of main() is negated to form the exit status: a truthy
    # return gives exit(False), i.e. status 0; a falsy return gives
    # exit(True), i.e. status 1.
    exit(not main())