diff options
-rwxr-xr-x | mpm/tools/mpm_shell.py | 139 |
1 files changed, 86 insertions, 53 deletions
diff --git a/mpm/tools/mpm_shell.py b/mpm/tools/mpm_shell.py index e480a4857..ed2998809 100755 --- a/mpm/tools/mpm_shell.py +++ b/mpm/tools/mpm_shell.py @@ -12,10 +12,8 @@ from __future__ import print_function import cmd import time import argparse -import threading +import multiprocessing from importlib import import_module -from mprpc import RPCClient -from mprpc.exceptions import RPCError try: from usrp_mpm.mpmtypes import MPM_RPC_PORT @@ -69,54 +67,97 @@ class MPMClaimer(object): """ Holds a claim. """ - def __init__(self, host, port, disc_callback): + def __init__(self, host, port): self.token = None - self._exit_loop = False - self._disc_callback = disc_callback - self._claim_loop = threading.Thread( + self.hijacked = False + self._cmd_q = multiprocessing.Queue() + self._token_q = multiprocessing.Queue() + self._claim_loop = multiprocessing.Process( target=self.claim_loop, name="Claimer Loop", - args=(host, port, self._disc_callback) + args=(host, port, self._cmd_q, self._token_q) ) self._claim_loop.start() - def claim_loop(self, host, port, disc_callback): + def claim_loop(self, host, port, cmd_q, token_q): """ Run a claim loop """ + from mprpc import RPCClient + from mprpc.exceptions import RPCError + cmd = None + token = None + exit_loop = False client = RPCClient(host, port, pack_params={'use_bin_type': True}) - self.token = client.call('claim', 'MPM Shell') try: - while not self._exit_loop: - client.call('reclaim', self.token) + while not exit_loop: + if token and not cmd: + client.call('reclaim', token) + elif cmd == 'claim': + if not token: + token = client.call('claim', 'MPM Shell') + else: + print("Already have claim") + token_q.put(token) + elif cmd == 'unclaim': + if token: + client.call('unclaim', token) + token = None + token_q.put(None) + elif cmd == 'exit': + if token: + client.call('unclaim', token) + token = None + token_q.put(None) + exit_loop = True time.sleep(1) - client.call('unclaim', self.token) + cmd = None + if not cmd_q.empty(): + cmd = cmd_q.get(False) except RPCError as ex: print("Unexpected RPC error in claimer loop!") print(str(ex)) - disc_callback() - self.token = None - def unclaim(self): + def exit(self): """ Unclaim device and exit claim loop. """ - self._exit_loop = True + self.unclaim() + self._cmd_q.put('exit') self._claim_loop.join() -class MPMHijacker(object): - """ - Looks like a claimer object, but doesn't actually claim. - """ - def __init__(self, token): - self.token = token - def unclaim(self): """ - Unclaim device and exit claim loop. + Unclaim device. + """ + if not self.hijacked: + self._cmd_q.put('unclaim') + else: + self.hijacked = False + self.token = None + + def claim(self): + """ + Claim device. + """ + self._cmd_q.put('claim') + self.token = self._token_q.get(True, 5.0) + + def get_token(self): + """ + Get current token (if any) """ - pass + if not self._token_q.empty(): + self.token = self._token_q.get(False) + return self.token + def hijack(self, token): + if self.token: + print("Already have token") + return + else: + self.token = token + self.hijacked = True class MPMShell(cmd.Cmd): """ @@ -127,10 +168,10 @@ class MPMShell(cmd.Cmd): self.prompt = "> " self.client = None self.remote_methods = [] - self._claimer = None self._host = host self._port = port self._device_info = None + self._claimer = MPMClaimer(self._host, self._port) if host is not None: self.connect(host, port) if claim: @@ -156,15 +197,16 @@ class MPMShell(cmd.Cmd): """ Template function to create new RPC shell commands """ + from mprpc.exceptions import RPCError if requires_token and \ - (self._claimer is None or self._claimer.token is None): + (self._claimer is None or self._claimer.get_token() is None): print("Cannot execute `{}' -- no claim available!") return try: if args or requires_token: expanded_args = self.expand_args(args) if requires_token: - expanded_args.insert(0, self._claimer.token) + expanded_args.insert(0, self._claimer.get_token()) response = self.client.call(command, *expanded_args) else: response = self.client.call(command) @@ -214,6 +256,8 @@ class MPMShell(cmd.Cmd): """ Launch a connection. """ + from mprpc import RPCClient + from mprpc.exceptions import RPCError print("Attempting to connect to {host}:{port}...".format( host=host, port=port )) @@ -239,9 +283,10 @@ class MPMShell(cmd.Cmd): """ Clean up after a connection was closed. """ + from mprpc.exceptions import RPCError self._device_info = None if self._claimer is not None: - self._claimer.unclaim() + self._claimer.exit() if self.client: try: self.client.close() @@ -257,37 +302,24 @@ class MPMShell(cmd.Cmd): def claim(self): " Initialize claim " - assert self.client is not None - if self._claimer is not None: - print("Claimer already active.") - return True print("Claiming device...") - self._claimer = MPMClaimer(self._host, self._port, self.unclaim_hook) + self._claimer.claim() return True def hijack(self, token): " Hijack running session " - assert self.client is not None - if self._claimer is not None: + if self._claimer.hijacked: print("Claimer already active. Can't hijack.") return False print("Hijacking device...") - self._claimer = MPMHijacker(token) + self._claimer.hijack(token) return True def unclaim(self): """ unclaim """ - if self._claimer is not None: - self._claimer.unclaim() - self._claimer = None - - def unclaim_hook(self): - """ - Hook - """ - pass + self._claimer.unclaim() def update_prompt(self): """ @@ -296,12 +328,13 @@ class MPMShell(cmd.Cmd): if self._device_info is None: self.prompt = '> ' else: - if self._claimer is None: + token = self._claimer.get_token() + if token is None: claim_status = '' - elif isinstance(self._claimer, MPMClaimer): - claim_status = ' [C]' - elif isinstance(self._claimer, MPMHijacker): + elif self._claimer.hijacked: claim_status = ' [H]' + else: + claim_status = ' [C]' self.prompt = '{dev_id}{claim_status}> '.format( dev_id=self._device_info.get( 'name', self._device_info.get('serial', '?') @@ -313,8 +346,8 @@ class MPMShell(cmd.Cmd): """ Takes a string and returns a list """ - if self._claimer is not None and self._claimer.token is not None: - args = args.replace('$T', str(self._claimer.token)) + if self._claimer is not None and self._claimer.get_token() is not None: + args = args.replace('$T', str(self._claimer.get_token())) eval_preamble = '=' args = args.strip() if args.startswith(eval_preamble): |