aboutsummaryrefslogtreecommitdiffstats
path: root/.ci/utils/tftp.py
blob: 44291cd93856f0b71072c562aca87d20b6605d44 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/usr/bin/env python3

import asyncio
import py3tftp.protocols
import pyroute2
import socket
import threading
from pathlib import Path


class FileReaderSingle:
    """
    Chunked binary reader over one fixed file.

    The file is opened when the reader is constructed and closed again as
    soon as a read comes back empty or shorter than the requested size; from
    then on the reader is marked finished and only yields empty chunks.
    """
    def __init__(self, path, fname_req, chunk_size=0):
        """
        path       -- file to read; file_size() calls .stat() on it, so a
                      pathlib.Path is expected
        fname_req  -- name asked for by the peer; currently ignored
        chunk_size -- default number of bytes returned by read_chunk()
        """
        # TODO: Should check fname_req against actual name
        self.path = path
        self.chunk_size = chunk_size
        self.finished = False
        # Pre-set so that __del__ has something to inspect if open() raises
        self._f = None
        self._f = open(self.path, 'rb')

    def file_size(self):
        """Size in bytes of the file as currently reported by stat()."""
        info = self.path.stat()
        return info.st_size

    def read_chunk(self, size=None):
        """
        Return the next chunk of the file.

        A falsy size (None or 0) selects the default chunk_size. Once the
        reader is finished, b'' is returned without touching the file.
        """
        if self.finished:
            return b''
        if not size:
            size = self.chunk_size

        chunk = self._f.read(size)
        # An empty or short read means the end of the file has been reached
        at_end = not chunk or (0 < size and len(chunk) < size)
        if at_end:
            self._f.close()
            self.finished = True

        return chunk

    def __del__(self):
        """Close the file if it was opened and is still open."""
        handle = self._f
        if not handle or handle.closed:
            return
        handle.close()


class TFTPServerSingle(py3tftp.protocols.BaseTFTPServerProtocol):
    """
    Server protocol that accepts read requests only and always hands out the
    one file it was constructed with.
    """
    def __init__(self, path, host_interface, loop, extra_opts):
        """
        Remember the file to serve; all other arguments are forwarded
        unchanged to the base class.
        """
        super().__init__(host_interface, loop, extra_opts)
        self.path = path

    def select_protocol(self, packet):
        """
        Return the protocol class for a read request. Every other kind of
        packet is rejected with ProtocolException.
        """
        if not packet.is_rrq():
            raise py3tftp.protocols.ProtocolException("Unhandled protocol")
        return py3tftp.protocols.RRQProtocol

    def select_file_handler(self, packet):
        """
        Return a factory building the reader for a read request, or None for
        any other packet.
        """
        if not packet.is_rrq():
            return None

        def make_reader(filename, opts):
            # The requested name is passed along but the reader always opens
            # self.path; opts ends up as FileReaderSingle's chunk_size.
            return FileReaderSingle(self.path, filename, opts)

        return make_reader


class TFTPServer:
    """
    Simple TFTP server, meant to be short-lived and capable of serving a single
    file only

    Use it as a context manager: entering binds the UDP endpoint and runs the
    server's event loop on a background thread, leaving shuts both down.

    Attributes set by the constructor:
      path     -- absolute Path of the file being served
      filename -- final component of that path
      port     -- UDP port the server listens on
      ip       -- preferred source address (RTA_PREFSRC) of the route to
                  remote_ip; the server listens on this address
    """
    def __init__(self, filename, remote_ip, port=None):
        """
        filename  -- file to serve; must be an existing regular file
        remote_ip -- address of the peer, used only to look up which local
                     address to listen on
        port      -- UDP port to listen on, or None to pick a free one

        Raises FileNotFoundError if filename is not an existing regular file
        and RuntimeError if no local source address can be determined.
        """
        self.path = Path(filename).absolute()
        # Explicit check rather than assert so it is not stripped by
        # "python -O". is_file() is False for missing paths as well.
        if not self.path.is_file():
            raise FileNotFoundError(
                "{} does not exist or is not a regular file".format(self.path))

        self.filename = self.path.name

        if port is None:
            # The server endpoint is a datagram (UDP) one, and UDP and TCP
            # port numbers are allocated independently, so probe with a UDP
            # socket. The port is released again here and only re-bound in
            # __enter__, so another process could still grab it in between.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                s.bind(('', 0))
                self.port = s.getsockname()[1]
        else:
            self.port = port

        # Always define self.ip so that a route without a preferred source
        # address is reported here instead of as an AttributeError later on.
        self.ip = None
        with pyroute2.IPRoute() as ipr:
            r = ipr.route('get', dst=remote_ip)
            attrs = r[0]['attrs'] if r else ()
            for attr in attrs:
                if attr[0] == 'RTA_PREFSRC':
                    self.ip = attr[1]
                    break
        if self.ip is None:
            raise RuntimeError(
                "Could not determine local address for reaching {}".format(
                    remote_ip))

    def __enter__(self):
        """
        Bind the UDP endpoint on (ip, port) and start serving on a background
        thread. Returns self.
        """
        self.loop = asyncio.new_event_loop()
        listen = self.loop.create_datagram_endpoint(
                lambda: TFTPServerSingle(self.path, self.ip, self.loop, {}),
                local_addr=(self.ip, self.port))

        def start_loop(loop):
            asyncio.set_event_loop(loop)
            loop.run_forever()

        try:
            self.transport, _protocol = self.loop.run_until_complete(listen)
        except BaseException:
            # __exit__ is not called when __enter__ fails, so release the
            # loop's resources here before propagating the error.
            self.loop.close()
            raise

        self.thread = threading.Thread(target=start_loop, args=(self.loop,))
        self.thread.start()
        return self

    def __exit__(self, type, value, exc):
        """Close the endpoint, stop the background thread and free the loop."""
        def shutdown():
            self.transport.close()
            # close() only schedules the actual socket teardown; stop the
            # loop one iteration later so that the teardown still runs.
            self.loop.call_soon(self.loop.stop)

        # The loop is running on the background thread and asyncio objects
        # are not thread-safe, so the transport must be closed from there.
        self.loop.call_soon_threadsafe(shutdown)
        self.thread.join()
        # The loop has stopped running; release its selector and wakeup
        # sockets.
        self.loop.close()