diff options
Diffstat (limited to '.ci/utils/tftp.py')
-rw-r--r-- | .ci/utils/tftp.py | 98 |
1 files changed, 98 insertions, 0 deletions
diff --git a/.ci/utils/tftp.py b/.ci/utils/tftp.py new file mode 100644 index 000000000..44291cd93 --- /dev/null +++ b/.ci/utils/tftp.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 + +import asyncio +import py3tftp.protocols +import pyroute2 +import socket +import threading +from pathlib import Path + + +class FileReaderSingle: + def __init__(self, path, fname_req, chunk_size=0): + self.path = path + # TODO: Should check fname_req against actual name + self.chunk_size = chunk_size + self._f = None + self._f = open(self.path, 'rb') + self.finished = False + + def file_size(self): + return self.path.stat().st_size + + def read_chunk(self, size=None): + size = size or self.chunk_size + if self.finished: + return b'' + + data = self._f.read(size) + if not data or (size > 0 and len(data) < size): + self._f.close() + self.finished = True + + return data + + def __del__(self): + if self._f and not self._f.closed: + self._f.close() + + +class TFTPServerSingle(py3tftp.protocols.BaseTFTPServerProtocol): + def __init__(self, path, host_interface, loop, extra_opts): + super().__init__(host_interface, loop, extra_opts) + self.path = path + + def select_protocol(self, packet): + if packet.is_rrq(): + return py3tftp.protocols.RRQProtocol + raise py3tftp.protocols.ProtocolException("Unhandled protocol") + + def select_file_handler(self, packet): + if packet.is_rrq(): + return lambda filename, opts: FileReaderSingle(self.path, filename, opts) + + +class TFTPServer: + """ + Simple TFTP server, meant to be short-lived and capable of serving a single + file only + """ + def __init__(self, filename, remote_ip, port=None): + self.path = Path(filename).absolute() + assert self.path.exists() + assert self.path.is_file() + + self.filename = self.path.name + + if port == None: + with socket.socket() as s: + s.bind(('', 0)) + self.port = s.getsockname()[1] + else: + self.port = port + + with pyroute2.IPRoute() as ipr: + r = ipr.route('get', dst=remote_ip) + for attr in r[0]['attrs']: + if attr[0] == 'RTA_PREFSRC': + self.ip = attr[1] + + def __enter__(self): + self.loop = asyncio.new_event_loop() + listen = self.loop.create_datagram_endpoint( + lambda: TFTPServerSingle(self.path, self.ip, self.loop, {}), + local_addr=(self.ip, self.port)) + + def start_loop(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + self.transport, protocol = self.loop.run_until_complete(listen) + self.thread = threading.Thread(target=start_loop, args=(self.loop,)) + self.thread.start() + return self + + def __exit__(self, type, value, exc): + self.transport.close() + self.loop.call_soon_threadsafe(self.loop.stop) + self.thread.join() |