diff options
-rw-r--r-- | mpm/python/usrp_mpm/rpc_server.py | 69 |
1 files changed, 47 insertions, 22 deletions
diff --git a/mpm/python/usrp_mpm/rpc_server.py b/mpm/python/usrp_mpm/rpc_server.py index ce82393ab..7b8d1edba 100644 --- a/mpm/python/usrp_mpm/rpc_server.py +++ b/mpm/python/usrp_mpm/rpc_server.py @@ -18,6 +18,9 @@ Implemented RPC Servers """ from __future__ import print_function +from random import choice +from string import ascii_letters, digits +from multiprocessing import Process from gevent.server import StreamServer from gevent.pool import Pool from gevent import signal @@ -26,11 +29,19 @@ from gevent import Greenlet from gevent import monkey monkey.patch_all() from mprpc import RPCServer -from random import choice -from string import ascii_letters, digits -from multiprocessing import Process from .mpmlog import get_main_logger +TOKEN_LEN = 16 # Length of the token string + +def no_claim(func): + " Decorator for functions that require no token check " + func._notok = True + return func + +def no_rpc(func): + " Decorator for functions that should not be exposed via RPC " + func._norpc = True + return func class MPMServer(RPCServer): """ @@ -56,36 +67,48 @@ class MPMServer(RPCServer): def _update_component_commands(self, component, namespace, storage): """ - Detect available methods for an object and add them to the RPC server - """ - for method in (m for m in dir(component) - if not m.startswith('_') and callable(getattr(component, m))): - if method.startswith('safe_'): - command_name = namespace + method.lstrip('safe_') - self._add_safe_command(getattr(component, method), command_name) + Detect available methods for an object and add them to the RPC server. + + We skip all private methods, and all methods that use the @no_rpc + decorator. + """ + for method_name in ( + m for m in dir(component) + if not m.startswith('_') \ + and callable(getattr(component, m)) \ + and not getattr(getattr(component, m), '_norpc', False) + ): + new_rpc_method = getattr(component, method_name) + command_name = namespace + method_name + if getattr(new_rpc_method, '_notok', False): + self._add_safe_command(new_rpc_method, command_name) else: - command_name = namespace + method - self._add_command(getattr(component, method), command_name) + self._add_claimed_command(new_rpc_method, command_name) getattr(self, storage).append(command_name) - def _add_command(self, function, command): + def _add_claimed_command(self, function, command): """ Adds a method with the name command to the RPC server - This command will require an acquired claim on the device + This command will require an acquired claim on the device, and a valid + token needs to be passed in for it to not fail. + + If the method does not require a token, use _add_safe_command(). """ self.log.trace("adding command %s pointing to %s", command, function) - def new_function(token, *args): - if token[:256] != self._state.claim_token.value: + def new_claimed_function(token, *args): + " Define a function that requires a claim token check " + if token[:TOKEN_LEN] != self._state.claim_token.value: return False return function(*args) - new_function.__doc__ = function.__doc__ - setattr(self, command, new_function) + new_claimed_function.__doc__ = function.__doc__ + setattr(self, command, new_claimed_function) def _add_safe_command(self, function, command): """ - Add a safe method which does not require a claim on the - device + Add a safe method which does not require a claim on the device. + If the method should only be called by claimers, use + _add_claimed_command(). """ self.log.trace("adding safe command %s pointing to %s", command, function) setattr(self, command, function) @@ -117,7 +140,9 @@ class MPMServer(RPCServer): return "" self.log.debug("claiming from: %s", self.client_host) self.periph_manager.claimed = True - self._state.claim_token.value = ''.join(choice(ascii_letters + digits) for _ in range(256)) + self._state.claim_token.value = ''.join( + choice(ascii_letters + digits) for _ in range(TOKEN_LEN) + ) self._state.claim_status.value = True self._state.lock.release() self.sender_id = sender_id @@ -133,7 +158,7 @@ class MPMServer(RPCServer): """ self._state.lock.acquire() if self._state.claim_status.value: - if self._state.claim_token.value == token[:256]: + if self._state.claim_token.value == token[:TOKEN_LEN]: self._state.lock.release() self.log.debug("reclaimed from: %s", self.client_host) self._reset_timer() |