diff options
Diffstat (limited to 'dpd/src')
-rw-r--r-- | dpd/src/Adapt.py | 174 | ||||
-rw-r--r-- | dpd/src/Agc.py | 165 | ||||
-rw-r--r-- | dpd/src/Dab_Util.py | 232 | ||||
-rw-r--r-- | dpd/src/Measure.py | 123 | ||||
-rw-r--r-- | dpd/src/Model.py | 336 | ||||
-rw-r--r-- | dpd/src/__init__.py | 0 | ||||
-rw-r--r-- | dpd/src/phase_align.py | 74 | ||||
-rwxr-xr-x | dpd/src/subsample_align.py | 83 | ||||
-rw-r--r-- | dpd/src/test_dab_Util.py | 62 | ||||
-rw-r--r-- | dpd/src/test_measure.py | 33 |
10 files changed, 1282 insertions, 0 deletions
diff --git a/dpd/src/Adapt.py b/dpd/src/Adapt.py new file mode 100644 index 0000000..2fb596f --- /dev/null +++ b/dpd/src/Adapt.py @@ -0,0 +1,174 @@ +# -*- coding: utf-8 -*- +# +# DPD Calculation Engine: updates ODR-DabMod's predistortion block. +# +# http://www.opendigitalradio.org +# Licence: The MIT License, see notice at the end of this file +""" +This module is used to change settings of ODR-DabMod using +the ZMQ remote control socket. +""" + +import zmq +import logging +import numpy as np + +class Adapt: + """Uses the ZMQ remote control to change parameters of the DabMod + + Parameters + ---------- + port : int + Port at which the ODR-DabMod is listening to connect the + ZMQ remote control. + """ + + def __init__(self, port, coef_path): + logging.info("Instantiate Adapt object") + self.port = port + self.coef_path = coef_path + self.host = "localhost" + self._context = zmq.Context() + + def _connect(self): + """Establish the connection to ODR-DabMod using + a ZMQ socket that is in request mode (Client). + Returns a socket""" + sock = self._context.socket(zmq.REQ) + sock.connect("tcp://%s:%d" % (self.host, self.port)) + + sock.send(b"ping") + data = [el.decode() for el in sock.recv_multipart()] + + if data != ['ok']: + raise RuntimeError( + "Could not ping server at %s %d: %s" % + (self.host, self.port, data)) + + return sock + + def send_receive(self, message): + """Send a message to ODR-DabMod. It always + returns the answer ODR-DabMod sends back. + + An example message could be + "get uhd txgain" or "set uhd txgain 50" + + Parameter + --------- + message : str + The message string that will be sent to the receiver. + """ + sock = self._connect() + logging.info("Send message: %s" % message) + msg_parts = message.split(" ") + for i, part in enumerate(msg_parts): + if i == len(msg_parts) - 1: + f = 0 + else: + f = zmq.SNDMORE + + sock.send(part.encode(), flags=f) + + data = [el.decode() for el in sock.recv_multipart()] + logging.info("Received message: %s" % message) + return data + + def set_txgain(self, gain): + """Set a new txgain for the ODR-DabMod. + + Parameters + ---------- + gain : int + new TX gain, in the same format as ODR-DabMod's config file + """ + # TODO this is specific to the B200 + if gain < 0 or gain > 89: + raise ValueError("Gain has to be in [0,89]") + return self.send_receive("set uhd txgain %d" % gain) + + def get_txgain(self): + """Get the txgain value in dB for the ODR-DabMod.""" + # TODO handle failure + return self.send_receive("get uhd txgain") + + def set_rxgain(self, gain): + """Set a new rxgain for the ODR-DabMod. + + Parameters + ---------- + gain : int + new RX gain, in the same format as ODR-DabMod's config file + """ + # TODO this is specific to the B200 + if gain < 0 or gain > 89: + raise ValueError("Gain has to be in [0,89]") + return self.send_receive("set uhd rxgain %d" % gain) + + def get_rxgain(self): + """Get the rxgain value in dB for the ODR-DabMod.""" + # TODO handle failure + return self.send_receive("get uhd rxgain") + + def _read_coef_file(self, path): + """Load the coefficients from the file in the format given in the README, + return ([AM coef], [PM coef])""" + coefs_am_out = [] + coefs_pm_out = [] + f = open(path, 'r') + lines = f.readlines() + n_coefs = int(lines[0]) + coefs = [float(l) for l in lines[1:]] + i = 0 + for c in coefs: + if i < n_coefs: + coefs_am_out.append(c) + elif i < 2*n_coefs: + coefs_pm_out.append(c) + else: + raise ValueError( + "Incorrect coef file format: too many coefficients in {}, should be {}, coefs are {}" + .format(path, n_coefs, coefs)) + i += 1 + f.close() + return (coefs_am_out, coefs_pm_out) + + def get_coefs(self): + return self._read_coef_file(self.coef_path) + + def _write_coef_file(self, coefs_am, coefs_pm, path): + assert(len(coefs_am) == len(coefs_pm)) + + f = open(path, 'w') + f.write("{}\n".format(len(coefs_am))) + for coef in coefs_am: + f.write("{}\n".format(coef)) + for coef in coefs_pm: + f.write("{}\n".format(coef)) + f.close() + + def set_coefs(self, coefs_am, coefs_pm): + self._write_coef_file(coefs_am, coefs_pm, self.coef_path) + self.send_receive("set memlesspoly coeffile {}".format(self.coef_path)) + +# The MIT License (MIT) +# +# Copyright (c) 2017 Andreas Steger, Matthias P. Braendli +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. diff --git a/dpd/src/Agc.py b/dpd/src/Agc.py new file mode 100644 index 0000000..1fd11c8 --- /dev/null +++ b/dpd/src/Agc.py @@ -0,0 +1,165 @@ +# -*- coding: utf-8 -*- +# +# Automatic Gain Control +# +# http://www.opendigitalradio.org +# Licence: The MIT License, see notice at the end of this file + +import datetime +import os +import logging +import time +logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) + +import numpy as np +import matplotlib +matplotlib.use('agg') +import matplotlib.pyplot as plt + +import src.Adapt as Adapt +import src.Measure as Measure + +class Agc: + """The goal of the automatic gain control is to set the + RX gain to a value at which all received amplitudes can + be detected. This means that the maximum possible amplitude + should be quantized at the highest possible digital value. + + A problem we have to face, is that the estimation of the + maximum amplitude by applying the max() function is very + unstable. This is due to the maximum’s rareness. Therefore + we estimate a far more robust value, such as the median, + and then approximate the maximum amplitude from it. + + Given this, we tune the RX gain in such a way, that the + received signal fulfills our desired property, of having + all amplitudes properly quantized.""" + + def __init__(self, measure, adapt, min_rxgain=25, peak_to_median=20): + assert isinstance(measure, Measure.Measure) + assert isinstance(adapt, Adapt.Adapt) + self.measure = measure + self.adapt = adapt + self.min_rxgain = min_rxgain + self.rxgain = self.min_rxgain + self.peak_to_median = peak_to_median + + def run(self): + self.adapt.set_rxgain(self.rxgain) + + for i in range(3): + # Measure + txframe_aligned, tx_ts, rxframe_aligned, rx_ts, rx_median= \ + self.measure.get_samples() + + # Estimate Maximum + rx_peak = self.peak_to_median * rx_median + correction_factor = 20*np.log10(1/rx_peak) + self.rxgain = self.rxgain + correction_factor + + assert self.min_rxgain <= self.rxgain, ("Desired RX Gain is {} which is smaller than the minimum of {}".format( + self.rxgain, self.min_rxgain)) + + logging.info("RX Median {:1.4f}, estimated peak {:1.4f}, correction factor {:1.4f}, new RX gain {:1.4f}".format( + rx_median, rx_peak, correction_factor, self.rxgain + )) + + self.adapt.set_rxgain(self.rxgain) + time.sleep(1) + + def plot_estimates(self): + """Plots the estimate of for Max, Median, Mean for different + number of samples.""" + self.adapt.set_rxgain(self.min_rxgain) + time.sleep(1) + + dt = datetime.datetime.now().isoformat() + fig_path = logging_path + "/" + dt + "_agc.pdf" + fig, axs = plt.subplots(2, 2, figsize=(3*6,1*6)) + axs = axs.ravel() + + for j in range(5): + txframe_aligned, tx_ts, rxframe_aligned, rx_ts, rx_median =\ + self.measure.get_samples() + + rxframe_aligned_abs = np.abs(rxframe_aligned) + + x = np.arange(100, rxframe_aligned_abs.shape[0], dtype=int) + rx_max_until = [] + rx_median_until = [] + rx_mean_until = [] + for i in x: + rx_max_until.append(np.max(rxframe_aligned_abs[:i])) + rx_median_until.append(np.median(rxframe_aligned_abs[:i])) + rx_mean_until.append(np.mean(rxframe_aligned_abs[:i])) + + axs[0].plot(x, + rx_max_until, + label="Run {}".format(j+1), + color=matplotlib.colors.hsv_to_rgb((1./(j+1.),0.8,0.7)), + linestyle="-", linewidth=0.25) + axs[0].set_xlabel("Samples") + axs[0].set_ylabel("Amplitude") + axs[0].set_title("Estimation for Maximum RX Amplitude") + axs[0].legend() + + axs[1].plot(x, + rx_median_until, + label="Run {}".format(j+1), + color=matplotlib.colors.hsv_to_rgb((1./(j+1.),0.9,0.7)), + linestyle="-", linewidth=0.25) + axs[1].set_xlabel("Samples") + axs[1].set_ylabel("Amplitude") + axs[1].set_title("Estimation for Median RX Amplitude") + axs[1].legend() + ylim_1 = axs[1].get_ylim() + + axs[2].plot(x, + rx_mean_until, + label="Run {}".format(j+1), + color=matplotlib.colors.hsv_to_rgb((1./(j+1.),0.9,0.7)), + linestyle="-", linewidth=0.25) + axs[2].set_xlabel("Samples") + axs[2].set_ylabel("Amplitude") + axs[2].set_title("Estimation for Mean RX Amplitude") + ylim_2 = axs[2].get_ylim() + axs[2].legend() + + axs[1].set_ylim(min(ylim_1[0], ylim_2[0]), + max(ylim_1[1], ylim_2[1])) + + fig.tight_layout() + fig.savefig(fig_path) + + axs[3].hist(rxframe_aligned_abs, bins=60) + axs[3].set_xlabel("Amplitude") + axs[3].set_ylabel("Frequency") + axs[3].set_title("Histogram of Amplitudes") + axs[3].legend() + + fig.tight_layout() + fig.savefig(fig_path) + fig.clf() + + +# The MIT License (MIT) +# +# Copyright (c) 2017 Andreas Steger +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. diff --git a/dpd/src/Dab_Util.py b/dpd/src/Dab_Util.py new file mode 100644 index 0000000..175b744 --- /dev/null +++ b/dpd/src/Dab_Util.py @@ -0,0 +1,232 @@ +# -*- coding: utf-8 -*- + +import datetime +import os +import logging +logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) + +import numpy as np +import scipy +import matplotlib +matplotlib.use('agg') +import matplotlib.pyplot as plt +import src.subsample_align as sa +import src.phase_align as pa +from scipy import signal + +class Dab_Util: + """Collection of methods that can be applied to an array + complex IQ samples of a DAB signal + """ + def __init__(self, sample_rate): + """ + :param sample_rate: sample rate [sample/sec] to use for calculations + """ + self.sample_rate = sample_rate + self.dab_bandwidth = 1536000 #Bandwidth of a dab signal + self.frame_ms = 96 #Duration of a Dab frame + + def lag(self, sig_orig, sig_rec): + """ + Find lag between two signals + Args: + sig_orig: The signal that has been sent + sig_rec: The signal that has been recored + """ + off = sig_rec.shape[0] + c = np.abs(signal.correlate(sig_orig, sig_rec)) + + if logging.getLogger().getEffectiveLevel() == logging.DEBUG: + dt = datetime.datetime.now().isoformat() + corr_path = (logging_path + "/" + dt + "_tx_rx_corr.pdf") + plt.plot(c, label="corr") + plt.legend() + plt.savefig(corr_path) + plt.clf() + + return np.argmax(c) - off + 1 + + def lag_upsampling(self, sig_orig, sig_rec, n_up): + if n_up != 1: + sig_orig_up = signal.resample(sig_orig, sig_orig.shape[0] * n_up) + sig_rec_up = signal.resample(sig_rec, sig_rec.shape[0] * n_up) + else: + sig_orig_up = sig_orig + sig_rec_up = sig_rec + l = self.lag(sig_orig_up, sig_rec_up) + l_orig = float(l) / n_up + return l_orig + + def subsample_align_upsampling(self, sig1, sig2, n_up=32): + """ + Returns an aligned version of sig1 and sig2 by cropping and subsample alignment + Using upsampling + """ + assert(sig1.shape[0] == sig2.shape[0]) + + if sig1.shape[0] % 2 == 1: + sig1 = sig1[:-1] + sig2 = sig2[:-1] + + sig1_up = signal.resample(sig1, sig1.shape[0] * n_up) + sig2_up = signal.resample(sig2, sig2.shape[0] * n_up) + + off_meas = self.lag_upsampling(sig2_up, sig1_up, n_up=1) + off = int(abs(off_meas)) + + if off_meas > 0: + sig1_up = sig1_up[:-off] + sig2_up = sig2_up[off:] + elif off_meas < 0: + sig1_up = sig1_up[off:] + sig2_up = sig2_up[:-off] + + sig1 = signal.resample(sig1_up, sig1_up.shape[0] / n_up).astype(np.complex64) + sig2 = signal.resample(sig2_up, sig2_up.shape[0] / n_up).astype(np.complex64) + return sig1, sig2 + + def subsample_align(self, sig1, sig2): + """ + Returns an aligned version of sig1 and sig2 by cropping and subsample alignment + """ + + if logging.getLogger().getEffectiveLevel() == logging.DEBUG: + dt = datetime.datetime.now().isoformat() + fig_path = logging_path + "/" + dt + "_sync_raw.pdf" + + fig, axs = plt.subplots(2) + axs[0].plot(np.abs(sig1[:128]), label="TX Frame") + axs[0].plot(np.abs(sig2[:128]), label="RX Frame") + axs[0].set_title("Raw Data") + axs[0].set_ylabel("Amplitude") + axs[0].set_xlabel("Samples") + axs[0].legend(loc=4) + + axs[1].plot(np.real(sig1[:128]), label="TX Frame") + axs[1].plot(np.real(sig2[:128]), label="RX Frame") + axs[1].set_title("Raw Data") + axs[1].set_ylabel("Real Part") + axs[1].set_xlabel("Samples") + axs[1].legend(loc=4) + + fig.tight_layout() + fig.savefig(fig_path) + fig.clf() + + logging.debug("Sig1_orig: %d %s, Sig2_orig: %d %s" % (len(sig1), sig1.dtype, len(sig2), sig2.dtype)) + off_meas = self.lag_upsampling(sig2, sig1, n_up=1) + off = int(abs(off_meas)) + + if off_meas > 0: + sig1 = sig1[:-off] + sig2 = sig2[off:] + elif off_meas < 0: + sig1 = sig1[off:] + sig2 = sig2[:-off] + + if off % 2 == 1: + sig1 = sig1[:-1] + sig2 = sig2[:-1] + + if logging.getLogger().getEffectiveLevel() == logging.DEBUG: + dt = datetime.datetime.now().isoformat() + fig_path = logging_path + "/" + dt + "_sync_sample_aligned.pdf" + + fig, axs = plt.subplots(2) + axs[0].plot(np.abs(sig1[:128]), label="TX Frame") + axs[0].plot(np.abs(sig2[:128]), label="RX Frame") + axs[0].set_title("Sample Aligned Data") + axs[0].set_ylabel("Amplitude") + axs[0].set_xlabel("Samples") + axs[0].legend(loc=4) + + axs[1].plot(np.real(sig1[:128]), label="TX Frame") + axs[1].plot(np.real(sig2[:128]), label="RX Frame") + axs[1].set_ylabel("Real Part") + axs[1].set_xlabel("Samples") + axs[1].legend(loc=4) + + fig.tight_layout() + fig.savefig(fig_path) + fig.clf() + + + sig2 = sa.subsample_align(sig2, sig1) + + if logging.getLogger().getEffectiveLevel() == logging.DEBUG: + dt = datetime.datetime.now().isoformat() + fig_path = logging_path + "/" + dt + "_sync_subsample_aligned.pdf" + + fig, axs = plt.subplots(2) + axs[0].plot(np.abs(sig1[:128]), label="TX Frame") + axs[0].plot(np.abs(sig2[:128]), label="RX Frame") + axs[0].set_title("Subsample Aligned") + axs[0].set_ylabel("Amplitude") + axs[0].set_xlabel("Samples") + axs[0].legend(loc=4) + + axs[1].plot(np.real(sig1[:128]), label="TX Frame") + axs[1].plot(np.real(sig2[:128]), label="RX Frame") + axs[1].set_ylabel("Real Part") + axs[1].set_xlabel("Samples") + axs[1].legend(loc=4) + + fig.tight_layout() + fig.savefig(fig_path) + fig.clf() + + sig2 = pa.phase_align(sig2, sig1) + + if logging.getLogger().getEffectiveLevel() == logging.DEBUG: + dt = datetime.datetime.now().isoformat() + fig_path = logging_path + "/" + dt + "_sync_phase_aligned.pdf" + + fig, axs = plt.subplots(2) + axs[0].plot(np.abs(sig1[:128]), label="TX Frame") + axs[0].plot(np.abs(sig2[:128]), label="RX Frame") + axs[0].set_title("Phase Aligned") + axs[0].set_ylabel("Amplitude") + axs[0].set_xlabel("Samples") + axs[0].legend(loc=4) + + axs[1].plot(np.real(sig1[:128]), label="TX Frame") + axs[1].plot(np.real(sig2[:128]), label="RX Frame") + axs[1].set_ylabel("Real Part") + axs[1].set_xlabel("Samples") + axs[1].legend(loc=4) + + fig.tight_layout() + fig.savefig(fig_path) + fig.clf() + + logging.debug("Sig1_cut: %d %s, Sig2_cut: %d %s, off: %d" % (len(sig1), sig1.dtype, len(sig2), sig2.dtype, off)) + return sig1, sig2 + + def fromfile(self, filename, offset=0, length=None): + if length is None: + return np.memmap(filename, dtype=np.complex64, mode='r', offset=64/8*offset) + else: + return np.memmap(filename, dtype=np.complex64, mode='r', offset=64/8*offset, shape=length) + + +# The MIT License (MIT) +# +# Copyright (c) 2017 Andreas Steger +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. diff --git a/dpd/src/Measure.py b/dpd/src/Measure.py new file mode 100644 index 0000000..e4fa8a2 --- /dev/null +++ b/dpd/src/Measure.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- + +import sys +import socket +import struct +import numpy as np +import logging +import src.Dab_Util as DU + +class Measure: + """Collect Measurement from DabMod""" + def __init__(self, samplerate, port, num_samples_to_request): + logging.info("Instantiate Measure object") + self.samplerate = samplerate + self.sizeof_sample = 8 # complex floats + self.port = port + self.num_samples_to_request = num_samples_to_request + + def _recv_exact(self, sock, num_bytes): + """Receive an exact number of bytes from a socket. This is + a wrapper around sock.recv() that can return less than the number + of requested bytes. + + Args: + sock (socket): Socket to receive data from. + num_bytes (int): Number of bytes that will be returned. + """ + bufs = [] + while num_bytes > 0: + b = sock.recv(num_bytes) + if len(b) == 0: + break + num_bytes -= len(b) + bufs.append(b) + return b''.join(bufs) + + def get_samples(self): + """Connect to ODR-DabMod, retrieve TX and RX samples, load + into numpy arrays, and return a tuple + (tx_timestamp, tx_samples, rx_timestamp, rx_samples) + where the timestamps are doubles, and the samples are numpy + arrays of complex floats, both having the same size + """ + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.connect(('localhost', self.port)) + + logging.debug("Send version") + s.sendall(b"\x01") + + logging.debug("Send request for {} samples".format(self.num_samples_to_request)) + s.sendall(struct.pack("=I", self.num_samples_to_request)) + + logging.debug("Wait for TX metadata") + num_samps, tx_second, tx_pps = struct.unpack("=III", self._recv_exact(s, 12)) + tx_ts = tx_second + tx_pps / 16384000.0 + + if num_samps > 0: + logging.debug("Receiving {} TX samples".format(num_samps)) + txframe_bytes = self._recv_exact(s, num_samps * self.sizeof_sample) + txframe = np.fromstring(txframe_bytes, dtype=np.complex64) + else: + txframe = np.array([], dtype=np.complex64) + + logging.debug("Wait for RX metadata") + rx_second, rx_pps = struct.unpack("=II", self._recv_exact(s, 8)) + rx_ts = rx_second + rx_pps / 16384000.0 + + if num_samps > 0: + logging.debug("Receiving {} RX samples".format(num_samps)) + rxframe_bytes = self._recv_exact(s, num_samps * self.sizeof_sample) + rxframe = np.fromstring(rxframe_bytes, dtype=np.complex64) + else: + rxframe = np.array([], dtype=np.complex64) + + # Normalize received signal with sent signal + rx_median = np.median(np.abs(rxframe)) + rxframe = rxframe / rx_median * np.median(np.abs(txframe)) + + if logging.getLogger().getEffectiveLevel() == logging.DEBUG: + logging.debug("txframe: min %f, max %f, median %f" % + (np.min(np.abs(txframe)), + np.max(np.abs(txframe)), + np.median(np.abs(txframe)))) + + logging.debug("rxframe: min %f, max %f, median %f" % + (np.min(np.abs(rxframe)), + np.max(np.abs(rxframe)), + np.median(np.abs(rxframe)))) + + logging.debug("Disconnecting") + s.close() + + du = DU.Dab_Util(self.samplerate) + txframe_aligned, rxframe_aligned = du.subsample_align(txframe, rxframe) + + logging.info( + "Measurement done, tx %d %s, rx %d %s, tx aligned %d %s, rx aligned %d %s" + % (len(txframe), txframe.dtype, len(rxframe), rxframe.dtype, + len(txframe_aligned), txframe_aligned.dtype, len(rxframe_aligned), rxframe_aligned.dtype) ) + + return txframe_aligned, tx_ts, rxframe_aligned, rx_ts, rx_median + +# The MIT License (MIT) +# +# Copyright (c) 2017 Andreas Steger +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. diff --git a/dpd/src/Model.py b/dpd/src/Model.py new file mode 100644 index 0000000..ae9f7b3 --- /dev/null +++ b/dpd/src/Model.py @@ -0,0 +1,336 @@ +# -*- coding: utf-8 -*- + +import datetime +import os +import logging +logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) + +from pynverse import inversefunc +import numpy as np +import matplotlib +matplotlib.use('agg') +import matplotlib.pyplot as plt +from sklearn.linear_model import Ridge + +class Model: + """Calculates new coefficients using the measurement and the old + coefficients""" + + def __init__(self, coefs_am, coefs_pm): + self.coefs_am = coefs_am + self.coefs_history = [coefs_am, ] + self.mses = [0, ] + self.errs = [0, ] + + self.coefs_pm = coefs_pm + self.coefs_pm_history = [coefs_pm, ] + self.errs_phase = [0, ] + + def sample_uniformly(self, txframe_aligned, rxframe_aligned, n_bins=4): + """This function returns tx and rx samples in a way + that the tx amplitudes have an approximate uniform + distribution with respect to the txframe_aligned amplitudes""" + txframe_aligned_abs = np.abs(txframe_aligned) + ccdf_min = 0 + ccdf_max = np.max(txframe_aligned_abs) + tx_hist, ccdf_edges = np.histogram(txframe_aligned_abs, + bins=n_bins, + range=(ccdf_min, ccdf_max)) + n_choise = np.min(tx_hist) + tx_choice = np.zeros(n_choise * n_bins, dtype=np.complex64) + rx_choice = np.zeros(n_choise * n_bins, dtype=np.complex64) + + for idx, bin in enumerate(tx_hist): + indices = np.where((txframe_aligned_abs >= ccdf_edges[idx]) & + (txframe_aligned_abs <= ccdf_edges[idx+1]))[0] + indices_choise = np.random.choice(indices, n_choise, replace=False) + rx_choice[idx*n_choise:(idx+1)*n_choise] = rxframe_aligned[indices_choise] + tx_choice[idx*n_choise:(idx+1)*n_choise] = txframe_aligned[indices_choise] + return tx_choice, rx_choice + + def get_next_coefs(self, txframe_aligned, rxframe_aligned): + tx_choice, rx_choice = self.sample_uniformly(txframe_aligned, rxframe_aligned) + + # Calculate new coefficients for AM/AM correction + rx_abs = np.abs(rx_choice) + rx_A = np.vstack([rx_abs, + rx_abs ** 3, + rx_abs ** 5, + rx_abs ** 7, + rx_abs ** 9, + ]).T + rx_dpd = np.sum(rx_A * self.coefs_am, axis=1) + rx_dpd = rx_dpd * ( + np.median(np.abs(tx_choice)) / np.median(np.abs(rx_dpd))) + + err = rx_dpd - np.abs(tx_choice) + self.errs.append(np.mean(np.abs(err ** 2))) + + a_delta = np.linalg.lstsq(rx_A, err)[0] + new_coefs = self.coefs_am - 0.1 * a_delta + new_coefs = new_coefs * (self.coefs_am[0] / new_coefs[0]) + logging.debug("a_delta {}".format(a_delta)) + logging.debug("new coefs_am {}".format(new_coefs)) + + # Calculate new coefficients for AM/PM correction + phase_diff_rad = (( + (np.angle(tx_choice) - + np.angle(rx_choice) + + np.pi) % (2 * np.pi)) - + np.pi + ) + + tx_abs = np.abs(tx_choice) + tx_abs_A = np.vstack([tx_abs, + tx_abs ** 2, + tx_abs ** 3, + tx_abs ** 4, + tx_abs ** 5, + ]).T + phase_dpd = np.sum(tx_abs_A * self.coefs_pm, axis=1) + + err_phase = phase_dpd - phase_diff_rad + self.errs_phase.append(np.mean(np.abs(err_phase ** 2))) + a_delta = np.linalg.lstsq(tx_abs_A, err_phase)[0] + new_coefs_pm = self.coefs_pm - 0.1 * a_delta + logging.debug("a_delta {}".format(a_delta)) + logging.debug("new new_coefs_pm {}".format(new_coefs_pm)) + + def dpd_phase(tx): + tx_abs = np.abs(tx) + tx_A_complex = np.vstack([tx, + tx * tx_abs ** 1, + tx * tx_abs ** 2, + tx * tx_abs ** 3, + tx * tx_abs ** 4, + ]).T + tx_dpd = np.sum(tx_A_complex * self.coefs_pm, axis=1) + return tx_dpd + + tx_range = np.linspace(0, 2) + phase_range_dpd = dpd_phase(tx_range) + + rx_A_complex = np.vstack([rx_choice, + rx_choice * rx_abs ** 2, + rx_choice * rx_abs ** 4, + rx_choice * rx_abs ** 6, + rx_choice * rx_abs ** 8, + ]).T + rx_post_distored = np.sum(rx_A_complex * self.coefs_am, axis=1) + rx_post_distored = rx_post_distored * ( + np.median(np.abs(tx_choice)) / + np.median(np.abs(rx_post_distored))) + mse = np.mean(np.abs((tx_choice - rx_post_distored) ** 2)) + logging.debug("MSE: {}".format(mse)) + self.mses.append(mse) + + def dpd(tx): + tx_abs = np.abs(tx) + tx_A_complex = np.vstack([tx, + tx * tx_abs ** 2, + tx * tx_abs ** 4, + tx * tx_abs ** 6, + tx * tx_abs ** 8, + ]).T + tx_dpd = np.sum(tx_A_complex * self.coefs_am, axis=1) + return tx_dpd + + rx_range = np.linspace(0, 1, num=100) + rx_range_dpd = dpd(rx_range) + rx_range = rx_range[(rx_range_dpd > 0) & (rx_range_dpd < 2)] + rx_range_dpd = rx_range_dpd[(rx_range_dpd > 0) & (rx_range_dpd < 2)] + + if logging.getLogger().getEffectiveLevel() == logging.DEBUG: + logging.debug("txframe: min %f, max %f, median %f" % + (np.min(np.abs(txframe_aligned)), + np.max(np.abs(txframe_aligned)), + np.median(np.abs(txframe_aligned)) + )) + + logging.debug("rxframe: min %f, max %f, median %f" % + (np.min(np.abs(rx_choice)), + np.max(np.abs(rx_choice)), + np.median(np.abs(rx_choice)) + )) + + dt = datetime.datetime.now().isoformat() + fig_path = logging_path + "/" + dt + "_Model.pdf" + + fig = plt.figure(figsize=(3*6, 1.5 * 6)) + + ax = plt.subplot(3,3,1) + ax.plot(np.abs(txframe_aligned[:128]), + label="TX sent", + linestyle=":") + ax.plot(np.abs(rxframe_aligned[:128]), + label="RX received", + color="red") + ax.set_title("Synchronized Signals of Iteration {}".format(len(self.coefs_history))) + ax.set_xlabel("Samples") + ax.set_ylabel("Amplitude") + ax.text(0, 0, "TX (max {:01.3f}, mean {:01.3f}, median {:01.3f})".format( + np.max(np.abs(txframe_aligned)), + np.mean(np.abs(txframe_aligned)), + np.median(np.abs(txframe_aligned)) + ), size = 8) + ax.legend(loc=4) + + ax = plt.subplot(3,3,2) + ax.plot(np.real(txframe_aligned[:128]), + label="TX sent", + linestyle=":") + ax.plot(np.real(rxframe_aligned[:128]), + label="RX received", + color="red") + ax.set_title("Synchronized Signals") + ax.set_xlabel("Samples") + ax.set_ylabel("Real Part") + ax.legend(loc=4) + + ax = plt.subplot(3,3,3) + ax.plot(np.abs(txframe_aligned[:128]), + label="TX Frame", + linestyle=":", + linewidth=0.5) + ax.plot(np.abs(rxframe_aligned[:128]), + label="RX Frame", + linestyle="--", + linewidth=0.5) + + rx_abs = np.abs(rxframe_aligned) + rx_A = np.vstack([rx_abs, + rx_abs ** 3, + rx_abs ** 5, + rx_abs ** 7, + rx_abs ** 9, + ]).T + rx_dpd = np.sum(rx_A * self.coefs_am, axis=1) + rx_dpd = rx_dpd * ( + np.median(np.abs(tx_choice)) / np.median(np.abs(rx_dpd))) + + ax.plot(np.abs(rx_dpd[:128]), + label="RX DPD Frame", + linestyle="-.", + linewidth=0.5) + + tx_abs = np.abs(np.abs(txframe_aligned[:128])) + tx_A = np.vstack([tx_abs, + tx_abs ** 3, + tx_abs ** 5, + tx_abs ** 7, + tx_abs ** 9, + ]).T + tx_dpd = np.sum(tx_A * new_coefs, axis=1) + tx_dpd_norm = tx_dpd * ( + np.median(np.abs(tx_choice)) / np.median(np.abs(tx_dpd))) + + ax.plot(np.abs(tx_dpd_norm[:128]), + label="TX DPD Frame Norm", + linestyle="-.", + linewidth=0.5) + ax.legend(loc=4) + ax.set_title("RX DPD") + ax.set_xlabel("Samples") + ax.set_ylabel("Amplitude") + + ax = plt.subplot(3,3,4) + ax.scatter( + np.abs(tx_choice[:1024]), + np.abs(rx_choice[:1024]), + s=0.1) + ax.plot(rx_range_dpd / self.coefs_am[0], rx_range, linewidth=0.25) + ax.set_title("Amplifier Characteristic") + ax.set_xlabel("TX Amplitude") + ax.set_ylabel("RX Amplitude") + + ax = plt.subplot(3,3,5) + ax.scatter( + np.abs(tx_choice[:1024]), + phase_diff_rad[:1024] * 180 / np.pi, + s=0.1 + ) + ax.plot(tx_range, phase_range_dpd * 180 / np.pi, linewidth=0.25) + ax.set_title("Amplifier Characteristic") + ax.set_xlabel("TX Amplitude") + ax.set_ylabel("Phase Difference [deg]") + + ax = plt.subplot(3,3,6) + ccdf_min, ccdf_max = 0, 1 + tx_hist, ccdf_edges = np.histogram(np.abs(txframe_aligned), + bins=60, + range=(ccdf_min, ccdf_max)) + tx_hist_normalized = tx_hist.astype(float)/np.sum(tx_hist) + ccdf = 1.0 - np.cumsum(tx_hist_normalized) + ax.semilogy(ccdf_edges[:-1], ccdf, label="CCDF") + ax.semilogy(ccdf_edges[:-1], + tx_hist_normalized, + label="Histogram", + drawstyle='steps') + ax.legend(loc=4) + ax.set_ylim(1e-5,2) + ax.set_title("Complementary Cumulative Distribution Function") + ax.set_xlabel("TX Amplitude") + ax.set_ylabel("Ratio of Samples larger than x") + + ax = plt.subplot(3,3,7) + coefs_history = np.array(self.coefs_history) + for idx, coef_hist in enumerate(coefs_history.T): + ax.plot(coef_hist, + label="Coef {}".format(idx), + linewidth=0.5) + ax.legend(loc=4) + ax.set_title("AM/AM Coefficient History") + ax.set_xlabel("Iterations") + ax.set_ylabel("Coefficient Value") + + ax = plt.subplot(3,3,8) + coefs_history = np.array(self.coefs_pm_history) + for idx, coef_hist in enumerate(coefs_history.T): + ax.plot(coef_hist, + label="Coef {}".format(idx), + linewidth=0.5) + ax.legend(loc=4) + ax.set_title("AM/PM Coefficient History") + ax.set_xlabel("Iterations") + ax.set_ylabel("Coefficient Value") + + ax = plt.subplot(3,3,9) + coefs_history = np.array(self.coefs_history) + ax.plot(self.mses, label="MSE") + ax.plot(self.errs, label="ERR") + ax.legend(loc=4) + ax.set_title("MSE History") + ax.set_xlabel("Iterations") + ax.set_ylabel("MSE") + + fig.tight_layout() + fig.savefig(fig_path) + fig.clf() + + self.coefs_am = new_coefs + self.coefs_history.append(self.coefs_am) + self.coefs_pm = new_coefs_pm + self.coefs_pm_history.append(self.coefs_pm) + return self.coefs_am, self.coefs_pm + +# The MIT License (MIT) +# +# Copyright (c) 2017 Andreas Steger +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. diff --git a/dpd/src/__init__.py b/dpd/src/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/dpd/src/__init__.py diff --git a/dpd/src/phase_align.py b/dpd/src/phase_align.py new file mode 100644 index 0000000..f03184b --- /dev/null +++ b/dpd/src/phase_align.py @@ -0,0 +1,74 @@ +import datetime +import os +import logging +logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) + +import numpy as np +from scipy import signal, optimize +import sys +import matplotlib.pyplot as plt + + +def phase_align(sig, ref_sig): + """Do phase alignment for sig relative to the reference signal + ref_sig. + + Returns the aligned signal""" + + angle_diff = (np.angle(sig) - np.angle(ref_sig)) % (2. * np.pi) + + real_diffs = np.cos(angle_diff) + imag_diffs = np.sin(angle_diff) + + if logging.getLogger().getEffectiveLevel() == logging.DEBUG: + dt = datetime.datetime.now().isoformat() + fig_path = logging_path + "/" + dt + "_phase_align.pdf" + + plt.subplot(511) + plt.hist(angle_diff, bins=60, label="Angle Diff") + plt.xlabel("Angle") + plt.ylabel("Count") + plt.legend(loc=4) + + plt.subplot(512) + plt.hist(real_diffs, bins=60, label="Real Diff") + plt.xlabel("Real Part") + plt.ylabel("Count") + plt.legend(loc=4) + + plt.subplot(513) + plt.hist(imag_diffs, bins=60, label="Imaginary Diff") + plt.xlabel("Imaginary Part") + plt.ylabel("Count") + plt.legend(loc=4) + + plt.subplot(514) + plt.plot(np.angle(ref_sig[:128]), label="ref_sig") + plt.plot(np.angle(sig[:128]), label="sig") + plt.xlabel("Angle") + plt.ylabel("Sample") + plt.legend(loc=4) + + real_diff = np.median(real_diffs) + imag_diff = np.median(imag_diffs) + + angle = np.angle(real_diff + 1j * imag_diff) + + logging.debug( + "Compensating phase by {} rad, {} degree. real median {}, imag median {}".format( + angle, angle*180./np.pi, real_diff, imag_diff + )) + sig = sig * np.exp(1j * -angle) + + if logging.getLogger().getEffectiveLevel() == logging.DEBUG: + plt.subplot(515) + plt.plot(np.angle(ref_sig[:128]), label="ref_sig") + plt.plot(np.angle(sig[:128]), label="sig") + plt.xlabel("Angle") + plt.ylabel("Sample") + plt.legend(loc=4) + plt.tight_layout() + plt.savefig(fig_path) + plt.clf() + + return sig diff --git a/dpd/src/subsample_align.py b/dpd/src/subsample_align.py new file mode 100755 index 0000000..0a51593 --- /dev/null +++ b/dpd/src/subsample_align.py @@ -0,0 +1,83 @@ +import datetime +import os +import logging +logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) + +import numpy as np +from scipy import signal, optimize +import matplotlib.pyplot as plt + +def gen_omega(length): + if (length % 2) == 1: + raise ValueError("Needs an even length array.") + + halflength = int(length/2) + factor = 2.0 * np.pi / length + + omega = np.zeros(length, dtype=np.float) + for i in range(halflength): + omega[i] = factor * i + + for i in range(halflength, length): + omega[i] = factor * (i - length) + + return omega + +def subsample_align(sig, ref_sig): + """Do subsample alignment for sig relative to the reference signal + ref_sig. The delay between the two must be less than sample + + Returns the aligned signal""" + + n = len(sig) + if (n % 2) == 1: + raise ValueError("Needs an even length signal.") + halflen = int(n/2) + + fft_sig = np.fft.fft(sig) + + omega = gen_omega(n) + + def correlate_for_delay(tau): + # A subsample offset between two signals corresponds, in the frequency + # domain, to a linearly increasing phase shift, whose slope + # corresponds to the delay. + # + # Here, we build this phase shift in rotate_vec, and multiply it with + # our signal. + + rotate_vec = np.exp(1j * tau * omega) + # zero-frequency is rotate_vec[0], so rotate_vec[N/2] is the + # bin corresponding to the [-1, 1, -1, 1, ...] time signal, which + # is both the maximum positive and negative frequency. + # I don't remember why we handle it differently. + rotate_vec[halflen] = np.cos(np.pi * tau) + + corr_sig = np.fft.ifft(rotate_vec * fft_sig) + + return -np.abs(np.sum(np.conj(corr_sig) * ref_sig)) + + optim_result = optimize.minimize_scalar(correlate_for_delay, bounds=(-1,1), method='bounded', options={'disp': True}) + + if optim_result.success: + best_tau = optim_result.x + + if 1: + corr = np.vectorize(correlate_for_delay) + ixs = np.linspace(-1, 1, 100) + taus = corr(ixs) + + dt = datetime.datetime.now().isoformat() + tau_path = (logging_path + "/" + dt + "_tau.pdf") + plt.plot(ixs, taus) + plt.title("Subsample correlation, minimum is best: {}".format(best_tau)) + plt.savefig(tau_path) + plt.clf() + + # Prepare rotate_vec = fft_sig with rotated phase + rotate_vec = np.exp(1j * best_tau * omega) + rotate_vec[halflen] = np.cos(np.pi * best_tau) + return np.fft.ifft(rotate_vec * fft_sig).astype(np.complex64) + else: + #print("Could not optimize: " + optim_result.message) + return np.zeros(0, dtype=np.complex64) diff --git a/dpd/src/test_dab_Util.py b/dpd/src/test_dab_Util.py new file mode 100644 index 0000000..0b2fa4f --- /dev/null +++ b/dpd/src/test_dab_Util.py @@ -0,0 +1,62 @@ +from unittest import TestCase + +import numpy as np +import pandas as pd +import src.Dab_Util as DU + +class TestDab_Util(TestCase): + + def test_subsample_align(self, sample_orig=r'../test_data/orig_rough_aligned.dat', + sample_rec =r'../test_data/recored_rough_aligned.dat', + length = 10240, max_size = 1000000): + du = DU.Dab_Util(8196000) + res1 = [] + res2 = [] + for i in range(10): + start = np.random.randint(50, max_size) + r = np.random.randint(-50, 50) + + s1 = du.fromfile(sample_orig, offset=start+r, length=length) + s2 = du.fromfile(sample_rec, offset=start, length=length) + + res1.append(du.lag_upsampling(s2, s1, 32)) + + s1_aligned, s2_aligned = du.subsample_align(s1, s2) + + res2.append(du.lag_upsampling(s2_aligned, s1_aligned, 32)) + + error_rate = np.mean(np.array(res2) != 0) + self.assertEqual(error_rate, 0.0, "The error rate for aligning was %.2f%%" + % error_rate * 100) + +#def test_using_aligned_pair(sample_orig=r'../data/orig_rough_aligned.dat', sample_rec =r'../data/recored_rough_aligned.dat', length = 10240, max_size = 1000000): +# res = [] +# for i in tqdm(range(100)): +# start = np.random.randint(50, max_size) +# r = np.random.randint(-50, 50) +# +# s1 = du.fromfile(sample_orig, offset=start+r, length=length) +# s2 = du.fromfile(sample_rec, offset=start, length=length) +# +# res.append({'offset':r, +# '1':r - du.lag_upsampling(s2, s1, n_up=1), +# '2':r - du.lag_upsampling(s2, s1, n_up=2), +# '3':r - du.lag_upsampling(s2, s1, n_up=3), +# '4':r - du.lag_upsampling(s2, s1, n_up=4), +# '8':r - du.lag_upsampling(s2, s1, n_up=8), +# '16':r - du.lag_upsampling(s2, s1, n_up=16), +# '32':r - du.lag_upsampling(s2, s1, n_up=32), +# }) +# df = pd.DataFrame(res) +# df = df.reindex_axis(sorted(df.columns), axis=1) +# print(df.describe()) +# +# +#print("Align using upsampling") +#for n_up in [1, 2, 3, 4, 7, 8, 16]: +# correct_ratio = test_phase_offset(lambda x,y: du.lag_upsampling(x,y,n_up), tol=1./n_up) +# print("%.1f%% of the tested offsets were measured within tolerance %.4f for n_up = %d" % (correct_ratio * 100, 1./n_up, n_up)) +#test_using_aligned_pair() +# +#print("Phase alignment") +#test_subsample_alignment() diff --git a/dpd/src/test_measure.py b/dpd/src/test_measure.py new file mode 100644 index 0000000..b695721 --- /dev/null +++ b/dpd/src/test_measure.py @@ -0,0 +1,33 @@ +from unittest import TestCase +from Measure import Measure +import socket + + +class TestMeasure(TestCase): + + def _open_socks(self): + sock_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock_server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_server.bind(('localhost', 1234)) + sock_server.listen(1) + + sock_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock_client.connect(('localhost', 1234)) + + conn_server, addr_server = sock_server.accept() + return conn_server, sock_client + + def test__recv_exact(self): + m = Measure(1234, 1) + payload = b"test payload" + + conn_server, sock_client = self._open_socks() + conn_server.send(payload) + rec = m._recv_exact(sock_client, len(payload)) + + self.assertEqual(rec, payload, + "Did not receive the same message as sended. (%s, %s)" % + (rec, payload)) + + def test_get_samples(self): + self.fail() |