aboutsummaryrefslogtreecommitdiffstats
path: root/mpm/python/usrp_mpm
diff options
context:
space:
mode:
Diffstat (limited to 'mpm/python/usrp_mpm')
-rw-r--r--mpm/python/usrp_mpm/rpc_server.py69
1 files changed, 47 insertions, 22 deletions
diff --git a/mpm/python/usrp_mpm/rpc_server.py b/mpm/python/usrp_mpm/rpc_server.py
index ce82393ab..7b8d1edba 100644
--- a/mpm/python/usrp_mpm/rpc_server.py
+++ b/mpm/python/usrp_mpm/rpc_server.py
@@ -18,6 +18,9 @@
Implemented RPC Servers
"""
from __future__ import print_function
+from random import choice
+from string import ascii_letters, digits
+from multiprocessing import Process
from gevent.server import StreamServer
from gevent.pool import Pool
from gevent import signal
@@ -26,11 +29,19 @@ from gevent import Greenlet
from gevent import monkey
monkey.patch_all()
from mprpc import RPCServer
-from random import choice
-from string import ascii_letters, digits
-from multiprocessing import Process
from .mpmlog import get_main_logger
+TOKEN_LEN = 16 # Length of the token string
+
+def no_claim(func):
+ " Decorator for functions that require no token check "
+ func._notok = True
+ return func
+
+def no_rpc(func):
+ " Decorator for functions that should not be exposed via RPC "
+ func._norpc = True
+ return func
class MPMServer(RPCServer):
"""
@@ -56,36 +67,48 @@ class MPMServer(RPCServer):
def _update_component_commands(self, component, namespace, storage):
"""
- Detect available methods for an object and add them to the RPC server
- """
- for method in (m for m in dir(component)
- if not m.startswith('_') and callable(getattr(component, m))):
- if method.startswith('safe_'):
- command_name = namespace + method.lstrip('safe_')
- self._add_safe_command(getattr(component, method), command_name)
+ Detect available methods for an object and add them to the RPC server.
+
+ We skip all private methods, and all methods that use the @no_rpc
+ decorator.
+ """
+ for method_name in (
+ m for m in dir(component)
+ if not m.startswith('_') \
+ and callable(getattr(component, m)) \
+ and not getattr(getattr(component, m), '_norpc', False)
+ ):
+ new_rpc_method = getattr(component, method_name)
+ command_name = namespace + method_name
+ if getattr(new_rpc_method, '_notok', False):
+ self._add_safe_command(new_rpc_method, command_name)
else:
- command_name = namespace + method
- self._add_command(getattr(component, method), command_name)
+ self._add_claimed_command(new_rpc_method, command_name)
getattr(self, storage).append(command_name)
- def _add_command(self, function, command):
+ def _add_claimed_command(self, function, command):
"""
Adds a method with the name command to the RPC server
- This command will require an acquired claim on the device
+ This command will require an acquired claim on the device, and a valid
+ token needs to be passed in for it to not fail.
+
+ If the method does not require a token, use _add_safe_command().
"""
self.log.trace("adding command %s pointing to %s", command, function)
- def new_function(token, *args):
- if token[:256] != self._state.claim_token.value:
+ def new_claimed_function(token, *args):
+ " Define a function that requires a claim token check "
+ if token[:TOKEN_LEN] != self._state.claim_token.value:
return False
return function(*args)
- new_function.__doc__ = function.__doc__
- setattr(self, command, new_function)
+ new_claimed_function.__doc__ = function.__doc__
+ setattr(self, command, new_claimed_function)
def _add_safe_command(self, function, command):
"""
- Add a safe method which does not require a claim on the
- device
+ Add a safe method which does not require a claim on the device.
+ If the method should only be called by claimers, use
+ _add_claimed_command().
"""
self.log.trace("adding safe command %s pointing to %s", command, function)
setattr(self, command, function)
@@ -117,7 +140,9 @@ class MPMServer(RPCServer):
return ""
self.log.debug("claiming from: %s", self.client_host)
self.periph_manager.claimed = True
- self._state.claim_token.value = ''.join(choice(ascii_letters + digits) for _ in range(256))
+ self._state.claim_token.value = ''.join(
+ choice(ascii_letters + digits) for _ in range(TOKEN_LEN)
+ )
self._state.claim_status.value = True
self._state.lock.release()
self.sender_id = sender_id
@@ -133,7 +158,7 @@ class MPMServer(RPCServer):
"""
self._state.lock.acquire()
if self._state.claim_status.value:
- if self._state.claim_token.value == token[:256]:
+ if self._state.claim_token.value == token[:TOKEN_LEN]:
self._state.lock.release()
self.log.debug("reclaimed from: %s", self.client_host)
self._reset_timer()