diff options
author | Matthias P. Braendli <matthias.braendli@mpb.li> | 2017-12-29 09:30:47 +0100 |
---|---|---|
committer | Matthias P. Braendli <matthias.braendli@mpb.li> | 2017-12-29 09:30:47 +0100 |
commit | 0c0f828c6bccee3aeb3049cb8b5bb480153cd3b6 (patch) | |
tree | 520dc4ff15dbc8dba056ea03d762d570b243f27d /dpd/src | |
parent | 9234155749be0c9ee3ae1269f47c2240d302c21a (diff) | |
parent | 8e3338479c180418a05ab030c60ba01c2a8615ca (diff) | |
download | dabmod-0c0f828c6bccee3aeb3049cb8b5bb480153cd3b6.tar.gz dabmod-0c0f828c6bccee3aeb3049cb8b5bb480153cd3b6.tar.bz2 dabmod-0c0f828c6bccee3aeb3049cb8b5bb480153cd3b6.zip |
Merge branch 'next' into outputRefactoring
Diffstat (limited to 'dpd/src')
-rw-r--r-- | dpd/src/Adapt.py | 35 | ||||
-rw-r--r-- | dpd/src/Dab_Util.py | 26 | ||||
-rw-r--r-- | dpd/src/ExtractStatistic.py | 8 | ||||
-rw-r--r-- | dpd/src/GlobalConfig.py (renamed from dpd/src/Const.py) | 39 | ||||
-rw-r--r-- | dpd/src/MER.py | 13 | ||||
-rw-r--r-- | dpd/src/Measure.py | 7 | ||||
-rw-r--r-- | dpd/src/Measure_Shoulders.py | 8 | ||||
-rw-r--r-- | dpd/src/Model.py | 6 | ||||
-rw-r--r-- | dpd/src/Model_AM.py | 7 | ||||
-rw-r--r-- | dpd/src/Model_Lut.py | 4 | ||||
-rw-r--r-- | dpd/src/Model_PM.py | 7 | ||||
-rw-r--r-- | dpd/src/Model_Poly.py | 6 | ||||
-rw-r--r-- | dpd/src/RX_Agc.py | 7 | ||||
-rw-r--r-- | dpd/src/Symbol_align.py | 10 | ||||
-rw-r--r-- | dpd/src/TX_Agc.py | 3 | ||||
-rw-r--r-- | dpd/src/phase_align.py | 6 | ||||
-rwxr-xr-x | dpd/src/subsample_align.py | 10 |
17 files changed, 92 insertions, 110 deletions
diff --git a/dpd/src/Adapt.py b/dpd/src/Adapt.py index 7e19a2c..329ee20 100644 --- a/dpd/src/Adapt.py +++ b/dpd/src/Adapt.py @@ -16,8 +16,6 @@ import os import datetime import pickle -logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) - LUT_LEN = 32 FORMAT_POLY = 1 FORMAT_LUT = 2 @@ -44,6 +42,19 @@ def _write_lut_file(scalefactor, lut, path): f.write("{}\n{}\n".format(coef.real, coef.imag)) f.close() +def dpddata_to_str(dpddata): + if dpddata[0] == "poly": + coefs_am = dpddata[1] + coefs_pm = dpddata[2] + return "dpd_coefs_am {}, dpd_coefs_pm {}".format( + coefs_am, coefs_pm) + elif dpddata[0] == "lut": + scalefactor = dpddata[1] + lut = dpddata[2] + return "LUT scalefactor {}, LUT {}".format( + scalefactor, lut) + else: + raise ValueError("Unknown dpddata type {}".format(dpddata[0])) class Adapt: """Uses the ZMQ remote control to change parameters of the DabMod @@ -55,8 +66,9 @@ class Adapt: ZMQ remote control. """ - def __init__(self, port, coef_path): + def __init__(self, config, port, coef_path): logging.debug("Instantiate Adapt object") + self.c = config self.port = port self.coef_path = coef_path self.host = "localhost" @@ -226,27 +238,30 @@ class Adapt: """Backup current settings to a file""" dt = datetime.datetime.now().isoformat() if path is None: - path = logging_path + "/" + dt + "_adapt.pkl" + if self.c.plot_location is not None: + path = self.c.plot_location + "/" + dt + "_adapt.pkl" + else: + raise Exception("Cannot dump Adapt without either plot_location or path set") d = { "txgain": self.get_txgain(), "rxgain": self.get_rxgain(), "digital_gain": self.get_digital_gain(), "predistorter": self.get_predistorter() } - with open(path, "w") as f: + with open(path, "wb") as f: pickle.dump(d, f) return path def load(self, path): """Restore settings from a file""" - with open(path, "r") as f: + with open(path, "rb") as f: d = pickle.load(f) - self.set_txgain(d["txgain"]) - self.set_digital_gain(d["digital_gain"]) - self.set_rxgain(d["rxgain"]) - self.set_predistorter(d["predistorter"]) + self.set_txgain(d["txgain"]) + self.set_digital_gain(d["digital_gain"]) + self.set_rxgain(d["rxgain"]) + self.set_predistorter(d["predistorter"]) # The MIT License (MIT) # diff --git a/dpd/src/Dab_Util.py b/dpd/src/Dab_Util.py index 2021f38..bc89a39 100644 --- a/dpd/src/Dab_Util.py +++ b/dpd/src/Dab_Util.py @@ -8,9 +8,6 @@ import datetime import os import logging - -logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) - import numpy as np import matplotlib @@ -33,10 +30,11 @@ class Dab_Util: complex IQ samples of a DAB signal """ - def __init__(self, sample_rate, plot=False): + def __init__(self, config, sample_rate, plot=False): """ :param sample_rate: sample rate [sample/sec] to use for calculations """ + self.c = config self.sample_rate = sample_rate self.dab_bandwidth = 1536000 # Bandwidth of a dab signal self.frame_ms = 96 # Duration of a Dab frame @@ -53,9 +51,9 @@ class Dab_Util: off = sig_rec.shape[0] c = np.abs(signal.correlate(sig_orig, sig_rec)) - if logging.getLogger().getEffectiveLevel() == logging.DEBUG and self.plot: + if self.plot and self.c.plot_location is not None: dt = datetime.datetime.now().isoformat() - corr_path = (logging_path + "/" + dt + "_tx_rx_corr.svg") + corr_path = self.c.plot_location + "/" + dt + "_tx_rx_corr.png" plt.plot(c, label="corr") plt.legend() plt.savefig(corr_path) @@ -107,9 +105,9 @@ class Dab_Util: Returns an aligned version of sig_tx and sig_rx by cropping and subsample alignment """ - if logging.getLogger().getEffectiveLevel() == logging.DEBUG and self.plot: + if self.plot and self.c.plot_location is not None: dt = datetime.datetime.now().isoformat() - fig_path = logging_path + "/" + dt + "_sync_raw.svg" + fig_path = self.c.plot_location + "/" + dt + "_sync_raw.png" fig, axs = plt.subplots(2) axs[0].plot(np.abs(sig_tx[:128]), label="TX Frame") @@ -151,9 +149,9 @@ class Dab_Util: sig_tx = sig_tx[:-1] sig_rx = sig_rx[:-1] - if logging.getLogger().getEffectiveLevel() == logging.DEBUG and self.plot: + if self.plot and self.c.plot_location is not None: dt = datetime.datetime.now().isoformat() - fig_path = logging_path + "/" + dt + "_sync_sample_aligned.svg" + fig_path = self.c.plot_location + "/" + dt + "_sync_sample_aligned.png" fig, axs = plt.subplots(2) axs[0].plot(np.abs(sig_tx[:128]), label="TX Frame") @@ -175,9 +173,9 @@ class Dab_Util: sig_rx = sa.subsample_align(sig_rx, sig_tx) - if logging.getLogger().getEffectiveLevel() == logging.DEBUG and self.plot: + if self.plot and self.c.plot_location is not None: dt = datetime.datetime.now().isoformat() - fig_path = logging_path + "/" + dt + "_sync_subsample_aligned.svg" + fig_path = self.c.plot_location + "/" + dt + "_sync_subsample_aligned.png" fig, axs = plt.subplots(2) axs[0].plot(np.abs(sig_tx[:128]), label="TX Frame") @@ -199,9 +197,9 @@ class Dab_Util: sig_rx = pa.phase_align(sig_rx, sig_tx) - if logging.getLogger().getEffectiveLevel() == logging.DEBUG and self.plot: + if self.plot and self.c.plot_location is not None: dt = datetime.datetime.now().isoformat() - fig_path = logging_path + "/" + dt + "_sync_phase_aligned.svg" + fig_path = self.c.plot_location + "/" + dt + "_sync_phase_aligned.png" fig, axs = plt.subplots(2) axs[0].plot(np.abs(sig_tx[:128]), label="TX Frame") diff --git a/dpd/src/ExtractStatistic.py b/dpd/src/ExtractStatistic.py index d27cd77..639513a 100644 --- a/dpd/src/ExtractStatistic.py +++ b/dpd/src/ExtractStatistic.py @@ -8,13 +8,10 @@ import numpy as np import matplotlib.pyplot as plt - import datetime import os import logging -logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) - def _check_input_extract(tx_dpd, rx_received): # Check data type @@ -64,10 +61,9 @@ class ExtractStatistic: self.plot = c.ES_plot def _plot_and_log(self, tx_values, rx_values, phase_diffs_values, phase_diffs_values_lists): - if logging.getLogger().getEffectiveLevel() == logging.DEBUG and self.plot: - + if self.plot and self.c.plot_location is not None: dt = datetime.datetime.now().isoformat() - fig_path = logging_path + "/" + dt + "_ExtractStatistic.png" + fig_path = self.c.plot_location + "/" + dt + "_ExtractStatistic.png" sub_rows = 3 sub_cols = 1 fig = plt.figure(figsize=(sub_cols * 6, sub_rows / 2. * 6)) diff --git a/dpd/src/Const.py b/dpd/src/GlobalConfig.py index 6c9bafa..b84b9d7 100644 --- a/dpd/src/Const.py +++ b/dpd/src/GlobalConfig.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # -# DPD Computation Engine, constants. +# DPD Computation Engine, constants and global configuration # # Source for DAB standard: etsi_EN_300_401_v010401p p145 # @@ -9,19 +9,20 @@ import numpy as np +class GlobalConfig: + def __init__(self, cli_args, plot_location): + self.sample_rate = cli_args.samplerate + assert self.sample_rate == 8192000 # By now only constants for 8192000 -class Const: - def __init__(self, sample_rate, target_median, plot): - assert sample_rate == 8192000 # By now only constants for 8192000 - self.sample_rate = sample_rate + self.plot_location = plot_location # DAB frame # Time domain - self.T_F = sample_rate / 2048000 * 196608 # Transmission frame duration - self.T_NULL = sample_rate / 2048000 * 2656 # Null symbol duration - self.T_S = sample_rate / 2048000 * 2552 # Duration of OFDM symbols of indices l = 1, 2, 3,... L; - self.T_U = sample_rate / 2048000 * 2048 # Inverse of carrier spacing - self.T_C = sample_rate / 2048000 * 504 # Duration of cyclic prefix + self.T_F = self.sample_rate / 2048000 * 196608 # Transmission frame duration + self.T_NULL = self.sample_rate / 2048000 * 2656 # Null symbol duration + self.T_S = self.sample_rate / 2048000 * 2552 # Duration of OFDM symbols of indices l = 1, 2, 3,... L; + self.T_U = self.sample_rate / 2048000 * 2048 # Inverse of carrier spacing + self.T_C = self.sample_rate / 2048000 * 504 # Duration of cyclic prefix # Frequency Domain # example: np.delete(fft[3328:4865], 768) @@ -34,10 +35,10 @@ class Const: # time per sample = 1 / sample_rate # frequency per bin = 1kHz # phase difference per sample offset = delta_t * 2 * pi * delta_freq - self.phase_offset_per_sample = 1. / sample_rate * 2 * np.pi * 1000 + self.phase_offset_per_sample = 1. / self.sample_rate * 2 * np.pi * 1000 # Constants for ExtractStatistic - self.ES_plot = plot + self.ES_plot = cli_args.plot self.ES_start = 0.0 self.ES_end = 1.0 self.ES_n_bins = 64 # Number of bins between ES_start and ES_end @@ -45,7 +46,7 @@ class Const: # Constants for Measure_Shoulder self.MS_enable = False - self.MS_plot = plot + self.MS_plot = cli_args.plot meas_offset = 976 # Offset from center frequency to measure shoulder [kHz] meas_width = 100 # Size of frequency delta to measure shoulder [kHz] @@ -63,24 +64,24 @@ class Const: self.MS_n_proc = 4 # Constants for MER - self.MER_plot = plot + self.MER_plot = cli_args.plot # Constants for Model - self.MDL_plot = True or plot # Override default + self.MDL_plot = cli_args.plot # Constants for Model_PM # Set all phase offsets to zero for TX amplitude < MPM_tx_min self.MPM_tx_min = 0.1 # Constants for TX_Agc - self.TAGC_max_txgain = 89 # USRP specific - self.TAGC_tx_median_target = target_median + self.TAGC_max_txgain = 89 # USRP B200 specific + self.TAGC_tx_median_target = cli_args.target_median self.TAGC_tx_median_max = self.TAGC_tx_median_target * 1.4 self.TAGC_tx_median_min = self.TAGC_tx_median_target / 1.4 # Constants for RX_AGC - self.RAGC_min_rxgain = 25 # USRP specific - self.RAGC_rx_median_target = self.TAGC_tx_median_target + self.RAGC_min_rxgain = 25 # USRP B200 specific + self.RAGC_rx_median_target = cli_args.target_median # The MIT License (MIT) # diff --git a/dpd/src/MER.py b/dpd/src/MER.py index f186261..693058d 100644 --- a/dpd/src/MER.py +++ b/dpd/src/MER.py @@ -8,11 +8,6 @@ import datetime import os import logging -try: - logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) -except: - logging_path = "/tmp/" - import numpy as np import matplotlib matplotlib.use('agg') @@ -76,9 +71,11 @@ class MER: spectrum = self._calc_spectrum(tx) - if self.plot: + if self.plot and self.c.plot_location is not None: dt = datetime.datetime.now().isoformat() - fig_path = logging_path + "/" + dt + "_MER" + debug_name + ".svg" + fig_path = self.c.plot_location + "/" + dt + "_MER" + debug_name + ".png" + else: + fig_path = None MERs = [] for i, (x, y) in enumerate(self._split_in_carrier( @@ -103,7 +100,7 @@ class MER: ylim = ax.get_ylim() ax.set_ylim(ylim[0] - (ylim[1] - ylim[0]) * 0.1, ylim[1]) - if self.plot: + if fig_path is not None: plt.tight_layout() plt.savefig(fig_path) plt.show() diff --git a/dpd/src/Measure.py b/dpd/src/Measure.py index d4b1d9e..b7423c6 100644 --- a/dpd/src/Measure.py +++ b/dpd/src/Measure.py @@ -11,14 +11,13 @@ import struct import numpy as np import src.Dab_Util as DU import os - import logging -logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) class Measure: """Collect Measurement from DabMod""" - def __init__(self, samplerate, port, num_samples_to_request): + def __init__(self, config, samplerate, port, num_samples_to_request): logging.info("Instantiate Measure object") + self.c = config self.samplerate = samplerate self.sizeof_sample = 8 # complex floats self.port = port @@ -106,7 +105,7 @@ class Measure: rx_median = np.median(np.abs(rxframe)) rxframe = rxframe / rx_median * np.median(np.abs(txframe)) - du = DU.Dab_Util(self.samplerate) + du = DU.Dab_Util(self.c, self.samplerate) txframe_aligned, rxframe_aligned = du.subsample_align(txframe, rxframe) logging.info( diff --git a/dpd/src/Measure_Shoulders.py b/dpd/src/Measure_Shoulders.py index c733dfd..fd90050 100644 --- a/dpd/src/Measure_Shoulders.py +++ b/dpd/src/Measure_Shoulders.py @@ -9,9 +9,6 @@ import datetime import os import logging import multiprocessing - -logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) - import numpy as np import matplotlib.pyplot as plt @@ -79,8 +76,11 @@ class Measure_Shoulders: self.plot = c.MS_plot def _plot(self, signal): + if self.c.plot_location is None: + return + dt = datetime.datetime.now().isoformat() - fig_path = logging_path + "/" + dt + "_sync_subsample_aligned.svg" + fig_path = self.c.plot_location + "/" + dt + "_sync_subsample_aligned.png" fft = calc_fft_db(signal, 100, self.c) peak, idxs_peak = _calc_peak(fft, self.c) diff --git a/dpd/src/Model.py b/dpd/src/Model.py index 67feeb6..b2c303f 100644 --- a/dpd/src/Model.py +++ b/dpd/src/Model.py @@ -2,6 +2,12 @@ from src.Model_Poly import Poly from src.Model_Lut import Lut +def select_model_from_dpddata(dpddata): + if dpddata[0] == 'lut': + return Lut + elif dpddata[0] == 'poly': + return Poly + # The MIT License (MIT) # # Copyright (c) 2017 Andreas Steger diff --git a/dpd/src/Model_AM.py b/dpd/src/Model_AM.py index d7e880c..9800d83 100644 --- a/dpd/src/Model_AM.py +++ b/dpd/src/Model_AM.py @@ -8,9 +8,6 @@ import datetime import os import logging - -logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) - import numpy as np import matplotlib.pyplot as plt @@ -55,12 +52,12 @@ class Model_AM: self.plot = plot def _plot(self, tx_dpd, rx_received, coefs_am, coefs_am_new): - if logging.getLogger().getEffectiveLevel() == logging.DEBUG and self.plot: + if self.plot and self.c.plot_location is not None: tx_range, rx_est = calc_line(coefs_am, 0, 0.6) tx_range_new, rx_est_new = calc_line(coefs_am_new, 0, 0.6) dt = datetime.datetime.now().isoformat() - fig_path = logging_path + "/" + dt + "_Model_AM.svg" + fig_path = self.c.plot_location + "/" + dt + "_Model_AM.png" sub_rows = 1 sub_cols = 1 fig = plt.figure(figsize=(sub_cols * 6, sub_rows / 2. * 6)) diff --git a/dpd/src/Model_Lut.py b/dpd/src/Model_Lut.py index 6d4db52..e70fdb0 100644 --- a/dpd/src/Model_Lut.py +++ b/dpd/src/Model_Lut.py @@ -7,12 +7,8 @@ import os import logging - -logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) - import numpy as np - class Lut: """Implements a model that calculates lookup table coefficients""" diff --git a/dpd/src/Model_PM.py b/dpd/src/Model_PM.py index d4f8c00..3aafea0 100644 --- a/dpd/src/Model_PM.py +++ b/dpd/src/Model_PM.py @@ -8,9 +8,6 @@ import datetime import os import logging - -logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) - import numpy as np import matplotlib.pyplot as plt @@ -41,12 +38,12 @@ class Model_PM: self.plot = plot def _plot(self, tx_dpd, phase_diff, coefs_pm, coefs_pm_new): - if logging.getLogger().getEffectiveLevel() == logging.DEBUG and self.plot: + if self.plot and self.c.plot_location is not None: tx_range, phase_diff_est = self.calc_line(coefs_pm, 0, 0.6) tx_range_new, phase_diff_est_new = self.calc_line(coefs_pm_new, 0, 0.6) dt = datetime.datetime.now().isoformat() - fig_path = logging_path + "/" + dt + "_Model_PM.svg" + fig_path = self.c.plot_location + "/" + dt + "_Model_PM.png" sub_rows = 1 sub_cols = 1 fig = plt.figure(figsize=(sub_cols * 6, sub_rows / 2. * 6)) diff --git a/dpd/src/Model_Poly.py b/dpd/src/Model_Poly.py index ff15941..cdfd319 100644 --- a/dpd/src/Model_Poly.py +++ b/dpd/src/Model_Poly.py @@ -7,9 +7,6 @@ import os import logging - -logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) - import numpy as np import src.Model_AM as Model_AM @@ -54,13 +51,10 @@ class Poly: self.model_am = Model_AM.Model_AM(c, plot=self.plot) self.model_pm = Model_PM.Model_PM(c, plot=self.plot) - self.plot = c.MDL_plot - def reset_coefs(self): self.coefs_am = np.zeros(5, dtype=np.float32) self.coefs_am[0] = 1 self.coefs_pm = np.zeros(5, dtype=np.float32) - return self.coefs_am, self.coefs_pm def train(self, tx_abs, rx_abs, phase_diff, lr=None): """ diff --git a/dpd/src/RX_Agc.py b/dpd/src/RX_Agc.py index 670fbbb..f778dee 100644 --- a/dpd/src/RX_Agc.py +++ b/dpd/src/RX_Agc.py @@ -9,8 +9,6 @@ import datetime import os import logging import time -logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) - import numpy as np import matplotlib matplotlib.use('agg') @@ -70,11 +68,14 @@ class Agc: def plot_estimates(self): """Plots the estimate of for Max, Median, Mean for different number of samples.""" + if self.c.plot_location is None: + return + self.adapt.set_rxgain(self.min_rxgain) time.sleep(1) dt = datetime.datetime.now().isoformat() - fig_path = logging_path + "/" + dt + "_agc.svg" + fig_path = self.c.plot_location + "/" + dt + "_agc.png" fig, axs = plt.subplots(2, 2, figsize=(3*6,1*6)) axs = axs.ravel() diff --git a/dpd/src/Symbol_align.py b/dpd/src/Symbol_align.py index d921f25..2a17a65 100644 --- a/dpd/src/Symbol_align.py +++ b/dpd/src/Symbol_align.py @@ -8,12 +8,6 @@ import datetime import os import logging - -try: - logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) -except: - logging_path = "/tmp/" - import numpy as np import scipy import matplotlib @@ -75,9 +69,9 @@ class Symbol_align: offset = peaks[np.argmin([tx_product_avg[peak] for peak in peaks])] - if logging.getLogger().getEffectiveLevel() == logging.DEBUG and self.plot: + if self.plot and self.c.plot_location is not None: dt = datetime.datetime.now().isoformat() - fig_path = logging_path + "/" + dt + "_Symbol_align.svg" + fig_path = self.c.plot_location + "/" + dt + "_Symbol_align.png" fig = plt.figure(figsize=(9, 9)) diff --git a/dpd/src/TX_Agc.py b/dpd/src/TX_Agc.py index 3c804fa..309193d 100644 --- a/dpd/src/TX_Agc.py +++ b/dpd/src/TX_Agc.py @@ -9,9 +9,6 @@ import datetime import os import logging import time - -logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) - import numpy as np import matplotlib diff --git a/dpd/src/phase_align.py b/dpd/src/phase_align.py index 68c216d..8654333 100644 --- a/dpd/src/phase_align.py +++ b/dpd/src/phase_align.py @@ -7,8 +7,6 @@ import datetime import os import logging -logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) - import numpy as np import matplotlib.pyplot as plt @@ -24,9 +22,9 @@ def phase_align(sig, ref_sig, plot=False): real_diffs = np.cos(angle_diff) imag_diffs = np.sin(angle_diff) - if logging.getLogger().getEffectiveLevel() == logging.DEBUG and plot: + if plot and self.c.plot_location is not None: dt = datetime.datetime.now().isoformat() - fig_path = logging_path + "/" + dt + "_phase_align.svg" + fig_path = self.c.plot_location + "/" + dt + "_phase_align.png" plt.subplot(511) plt.hist(angle_diff, bins=60, label="Angle Diff") diff --git a/dpd/src/subsample_align.py b/dpd/src/subsample_align.py index 68f3591..20ae56b 100755 --- a/dpd/src/subsample_align.py +++ b/dpd/src/subsample_align.py @@ -7,14 +7,10 @@ import datetime import logging import os - -logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename) - import numpy as np from scipy import optimize import matplotlib.pyplot as plt - def gen_omega(length): if (length % 2) == 1: raise ValueError("Needs an even length array.") @@ -32,7 +28,7 @@ def gen_omega(length): return omega -def subsample_align(sig, ref_sig, plot=False): +def subsample_align(sig, ref_sig, plot_location=None): """Do subsample alignment for sig relative to the reference signal ref_sig. The delay between the two must be less than sample @@ -72,13 +68,13 @@ def subsample_align(sig, ref_sig, plot=False): if optim_result.success: best_tau = optim_result.x - if plot: + if plot_location is not None: corr = np.vectorize(correlate_for_delay) ixs = np.linspace(-1, 1, 100) taus = corr(ixs) dt = datetime.datetime.now().isoformat() - tau_path = (logging_path + "/" + dt + "_tau.svg") + tau_path = (plot_location + "/" + dt + "_tau.png") plt.plot(ixs, taus) plt.title("Subsample correlation, minimum is best: {}".format(best_tau)) plt.savefig(tau_path) |