summaryrefslogtreecommitdiffstats
path: root/dpd/src
diff options
context:
space:
mode:
Diffstat (limited to 'dpd/src')
-rw-r--r--dpd/src/Adapt.py35
-rw-r--r--dpd/src/Dab_Util.py26
-rw-r--r--dpd/src/ExtractStatistic.py8
-rw-r--r--dpd/src/GlobalConfig.py (renamed from dpd/src/Const.py)39
-rw-r--r--dpd/src/MER.py13
-rw-r--r--dpd/src/Measure.py7
-rw-r--r--dpd/src/Measure_Shoulders.py8
-rw-r--r--dpd/src/Model.py6
-rw-r--r--dpd/src/Model_AM.py7
-rw-r--r--dpd/src/Model_Lut.py4
-rw-r--r--dpd/src/Model_PM.py7
-rw-r--r--dpd/src/Model_Poly.py6
-rw-r--r--dpd/src/RX_Agc.py7
-rw-r--r--dpd/src/Symbol_align.py10
-rw-r--r--dpd/src/TX_Agc.py3
-rw-r--r--dpd/src/phase_align.py6
-rwxr-xr-xdpd/src/subsample_align.py10
17 files changed, 92 insertions, 110 deletions
diff --git a/dpd/src/Adapt.py b/dpd/src/Adapt.py
index 7e19a2c..329ee20 100644
--- a/dpd/src/Adapt.py
+++ b/dpd/src/Adapt.py
@@ -16,8 +16,6 @@ import os
import datetime
import pickle
-logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename)
-
LUT_LEN = 32
FORMAT_POLY = 1
FORMAT_LUT = 2
@@ -44,6 +42,19 @@ def _write_lut_file(scalefactor, lut, path):
f.write("{}\n{}\n".format(coef.real, coef.imag))
f.close()
+def dpddata_to_str(dpddata):
+ if dpddata[0] == "poly":
+ coefs_am = dpddata[1]
+ coefs_pm = dpddata[2]
+ return "dpd_coefs_am {}, dpd_coefs_pm {}".format(
+ coefs_am, coefs_pm)
+ elif dpddata[0] == "lut":
+ scalefactor = dpddata[1]
+ lut = dpddata[2]
+ return "LUT scalefactor {}, LUT {}".format(
+ scalefactor, lut)
+ else:
+ raise ValueError("Unknown dpddata type {}".format(dpddata[0]))
class Adapt:
"""Uses the ZMQ remote control to change parameters of the DabMod
@@ -55,8 +66,9 @@ class Adapt:
ZMQ remote control.
"""
- def __init__(self, port, coef_path):
+ def __init__(self, config, port, coef_path):
logging.debug("Instantiate Adapt object")
+ self.c = config
self.port = port
self.coef_path = coef_path
self.host = "localhost"
@@ -226,27 +238,30 @@ class Adapt:
"""Backup current settings to a file"""
dt = datetime.datetime.now().isoformat()
if path is None:
- path = logging_path + "/" + dt + "_adapt.pkl"
+ if self.c.plot_location is not None:
+ path = self.c.plot_location + "/" + dt + "_adapt.pkl"
+ else:
+ raise Exception("Cannot dump Adapt without either plot_location or path set")
d = {
"txgain": self.get_txgain(),
"rxgain": self.get_rxgain(),
"digital_gain": self.get_digital_gain(),
"predistorter": self.get_predistorter()
}
- with open(path, "w") as f:
+ with open(path, "wb") as f:
pickle.dump(d, f)
return path
def load(self, path):
"""Restore settings from a file"""
- with open(path, "r") as f:
+ with open(path, "rb") as f:
d = pickle.load(f)
- self.set_txgain(d["txgain"])
- self.set_digital_gain(d["digital_gain"])
- self.set_rxgain(d["rxgain"])
- self.set_predistorter(d["predistorter"])
+ self.set_txgain(d["txgain"])
+ self.set_digital_gain(d["digital_gain"])
+ self.set_rxgain(d["rxgain"])
+ self.set_predistorter(d["predistorter"])
# The MIT License (MIT)
#
diff --git a/dpd/src/Dab_Util.py b/dpd/src/Dab_Util.py
index 2021f38..bc89a39 100644
--- a/dpd/src/Dab_Util.py
+++ b/dpd/src/Dab_Util.py
@@ -8,9 +8,6 @@
import datetime
import os
import logging
-
-logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename)
-
import numpy as np
import matplotlib
@@ -33,10 +30,11 @@ class Dab_Util:
complex IQ samples of a DAB signal
"""
- def __init__(self, sample_rate, plot=False):
+ def __init__(self, config, sample_rate, plot=False):
"""
:param sample_rate: sample rate [sample/sec] to use for calculations
"""
+ self.c = config
self.sample_rate = sample_rate
self.dab_bandwidth = 1536000 # Bandwidth of a dab signal
self.frame_ms = 96 # Duration of a Dab frame
@@ -53,9 +51,9 @@ class Dab_Util:
off = sig_rec.shape[0]
c = np.abs(signal.correlate(sig_orig, sig_rec))
- if logging.getLogger().getEffectiveLevel() == logging.DEBUG and self.plot:
+ if self.plot and self.c.plot_location is not None:
dt = datetime.datetime.now().isoformat()
- corr_path = (logging_path + "/" + dt + "_tx_rx_corr.svg")
+ corr_path = self.c.plot_location + "/" + dt + "_tx_rx_corr.png"
plt.plot(c, label="corr")
plt.legend()
plt.savefig(corr_path)
@@ -107,9 +105,9 @@ class Dab_Util:
Returns an aligned version of sig_tx and sig_rx by cropping and subsample alignment
"""
- if logging.getLogger().getEffectiveLevel() == logging.DEBUG and self.plot:
+ if self.plot and self.c.plot_location is not None:
dt = datetime.datetime.now().isoformat()
- fig_path = logging_path + "/" + dt + "_sync_raw.svg"
+ fig_path = self.c.plot_location + "/" + dt + "_sync_raw.png"
fig, axs = plt.subplots(2)
axs[0].plot(np.abs(sig_tx[:128]), label="TX Frame")
@@ -151,9 +149,9 @@ class Dab_Util:
sig_tx = sig_tx[:-1]
sig_rx = sig_rx[:-1]
- if logging.getLogger().getEffectiveLevel() == logging.DEBUG and self.plot:
+ if self.plot and self.c.plot_location is not None:
dt = datetime.datetime.now().isoformat()
- fig_path = logging_path + "/" + dt + "_sync_sample_aligned.svg"
+ fig_path = self.c.plot_location + "/" + dt + "_sync_sample_aligned.png"
fig, axs = plt.subplots(2)
axs[0].plot(np.abs(sig_tx[:128]), label="TX Frame")
@@ -175,9 +173,9 @@ class Dab_Util:
sig_rx = sa.subsample_align(sig_rx, sig_tx)
- if logging.getLogger().getEffectiveLevel() == logging.DEBUG and self.plot:
+ if self.plot and self.c.plot_location is not None:
dt = datetime.datetime.now().isoformat()
- fig_path = logging_path + "/" + dt + "_sync_subsample_aligned.svg"
+ fig_path = self.c.plot_location + "/" + dt + "_sync_subsample_aligned.png"
fig, axs = plt.subplots(2)
axs[0].plot(np.abs(sig_tx[:128]), label="TX Frame")
@@ -199,9 +197,9 @@ class Dab_Util:
sig_rx = pa.phase_align(sig_rx, sig_tx)
- if logging.getLogger().getEffectiveLevel() == logging.DEBUG and self.plot:
+ if self.plot and self.c.plot_location is not None:
dt = datetime.datetime.now().isoformat()
- fig_path = logging_path + "/" + dt + "_sync_phase_aligned.svg"
+ fig_path = self.c.plot_location + "/" + dt + "_sync_phase_aligned.png"
fig, axs = plt.subplots(2)
axs[0].plot(np.abs(sig_tx[:128]), label="TX Frame")
diff --git a/dpd/src/ExtractStatistic.py b/dpd/src/ExtractStatistic.py
index d27cd77..639513a 100644
--- a/dpd/src/ExtractStatistic.py
+++ b/dpd/src/ExtractStatistic.py
@@ -8,13 +8,10 @@
import numpy as np
import matplotlib.pyplot as plt
-
import datetime
import os
import logging
-logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename)
-
def _check_input_extract(tx_dpd, rx_received):
# Check data type
@@ -64,10 +61,9 @@ class ExtractStatistic:
self.plot = c.ES_plot
def _plot_and_log(self, tx_values, rx_values, phase_diffs_values, phase_diffs_values_lists):
- if logging.getLogger().getEffectiveLevel() == logging.DEBUG and self.plot:
-
+ if self.plot and self.c.plot_location is not None:
dt = datetime.datetime.now().isoformat()
- fig_path = logging_path + "/" + dt + "_ExtractStatistic.png"
+ fig_path = self.c.plot_location + "/" + dt + "_ExtractStatistic.png"
sub_rows = 3
sub_cols = 1
fig = plt.figure(figsize=(sub_cols * 6, sub_rows / 2. * 6))
diff --git a/dpd/src/Const.py b/dpd/src/GlobalConfig.py
index 6c9bafa..b84b9d7 100644
--- a/dpd/src/Const.py
+++ b/dpd/src/GlobalConfig.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
#
-# DPD Computation Engine, constants.
+# DPD Computation Engine, constants and global configuration
#
# Source for DAB standard: etsi_EN_300_401_v010401p p145
#
@@ -9,19 +9,20 @@
import numpy as np
+class GlobalConfig:
+ def __init__(self, cli_args, plot_location):
+ self.sample_rate = cli_args.samplerate
+ assert self.sample_rate == 8192000 # By now only constants for 8192000
-class Const:
- def __init__(self, sample_rate, target_median, plot):
- assert sample_rate == 8192000 # By now only constants for 8192000
- self.sample_rate = sample_rate
+ self.plot_location = plot_location
# DAB frame
# Time domain
- self.T_F = sample_rate / 2048000 * 196608 # Transmission frame duration
- self.T_NULL = sample_rate / 2048000 * 2656 # Null symbol duration
- self.T_S = sample_rate / 2048000 * 2552 # Duration of OFDM symbols of indices l = 1, 2, 3,... L;
- self.T_U = sample_rate / 2048000 * 2048 # Inverse of carrier spacing
- self.T_C = sample_rate / 2048000 * 504 # Duration of cyclic prefix
+ self.T_F = self.sample_rate / 2048000 * 196608 # Transmission frame duration
+ self.T_NULL = self.sample_rate / 2048000 * 2656 # Null symbol duration
+ self.T_S = self.sample_rate / 2048000 * 2552 # Duration of OFDM symbols of indices l = 1, 2, 3,... L;
+ self.T_U = self.sample_rate / 2048000 * 2048 # Inverse of carrier spacing
+ self.T_C = self.sample_rate / 2048000 * 504 # Duration of cyclic prefix
# Frequency Domain
# example: np.delete(fft[3328:4865], 768)
@@ -34,10 +35,10 @@ class Const:
# time per sample = 1 / sample_rate
# frequency per bin = 1kHz
# phase difference per sample offset = delta_t * 2 * pi * delta_freq
- self.phase_offset_per_sample = 1. / sample_rate * 2 * np.pi * 1000
+ self.phase_offset_per_sample = 1. / self.sample_rate * 2 * np.pi * 1000
# Constants for ExtractStatistic
- self.ES_plot = plot
+ self.ES_plot = cli_args.plot
self.ES_start = 0.0
self.ES_end = 1.0
self.ES_n_bins = 64 # Number of bins between ES_start and ES_end
@@ -45,7 +46,7 @@ class Const:
# Constants for Measure_Shoulder
self.MS_enable = False
- self.MS_plot = plot
+ self.MS_plot = cli_args.plot
meas_offset = 976 # Offset from center frequency to measure shoulder [kHz]
meas_width = 100 # Size of frequency delta to measure shoulder [kHz]
@@ -63,24 +64,24 @@ class Const:
self.MS_n_proc = 4
# Constants for MER
- self.MER_plot = plot
+ self.MER_plot = cli_args.plot
# Constants for Model
- self.MDL_plot = True or plot # Override default
+ self.MDL_plot = cli_args.plot
# Constants for Model_PM
# Set all phase offsets to zero for TX amplitude < MPM_tx_min
self.MPM_tx_min = 0.1
# Constants for TX_Agc
- self.TAGC_max_txgain = 89 # USRP specific
- self.TAGC_tx_median_target = target_median
+ self.TAGC_max_txgain = 89 # USRP B200 specific
+ self.TAGC_tx_median_target = cli_args.target_median
self.TAGC_tx_median_max = self.TAGC_tx_median_target * 1.4
self.TAGC_tx_median_min = self.TAGC_tx_median_target / 1.4
# Constants for RX_AGC
- self.RAGC_min_rxgain = 25 # USRP specific
- self.RAGC_rx_median_target = self.TAGC_tx_median_target
+ self.RAGC_min_rxgain = 25 # USRP B200 specific
+ self.RAGC_rx_median_target = cli_args.target_median
# The MIT License (MIT)
#
diff --git a/dpd/src/MER.py b/dpd/src/MER.py
index f186261..693058d 100644
--- a/dpd/src/MER.py
+++ b/dpd/src/MER.py
@@ -8,11 +8,6 @@
import datetime
import os
import logging
-try:
- logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename)
-except:
- logging_path = "/tmp/"
-
import numpy as np
import matplotlib
matplotlib.use('agg')
@@ -76,9 +71,11 @@ class MER:
spectrum = self._calc_spectrum(tx)
- if self.plot:
+ if self.plot and self.c.plot_location is not None:
dt = datetime.datetime.now().isoformat()
- fig_path = logging_path + "/" + dt + "_MER" + debug_name + ".svg"
+ fig_path = self.c.plot_location + "/" + dt + "_MER" + debug_name + ".png"
+ else:
+ fig_path = None
MERs = []
for i, (x, y) in enumerate(self._split_in_carrier(
@@ -103,7 +100,7 @@ class MER:
ylim = ax.get_ylim()
ax.set_ylim(ylim[0] - (ylim[1] - ylim[0]) * 0.1, ylim[1])
- if self.plot:
+ if fig_path is not None:
plt.tight_layout()
plt.savefig(fig_path)
plt.show()
diff --git a/dpd/src/Measure.py b/dpd/src/Measure.py
index d4b1d9e..b7423c6 100644
--- a/dpd/src/Measure.py
+++ b/dpd/src/Measure.py
@@ -11,14 +11,13 @@ import struct
import numpy as np
import src.Dab_Util as DU
import os
-
import logging
-logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename)
class Measure:
"""Collect Measurement from DabMod"""
- def __init__(self, samplerate, port, num_samples_to_request):
+ def __init__(self, config, samplerate, port, num_samples_to_request):
logging.info("Instantiate Measure object")
+ self.c = config
self.samplerate = samplerate
self.sizeof_sample = 8 # complex floats
self.port = port
@@ -106,7 +105,7 @@ class Measure:
rx_median = np.median(np.abs(rxframe))
rxframe = rxframe / rx_median * np.median(np.abs(txframe))
- du = DU.Dab_Util(self.samplerate)
+ du = DU.Dab_Util(self.c, self.samplerate)
txframe_aligned, rxframe_aligned = du.subsample_align(txframe, rxframe)
logging.info(
diff --git a/dpd/src/Measure_Shoulders.py b/dpd/src/Measure_Shoulders.py
index c733dfd..fd90050 100644
--- a/dpd/src/Measure_Shoulders.py
+++ b/dpd/src/Measure_Shoulders.py
@@ -9,9 +9,6 @@ import datetime
import os
import logging
import multiprocessing
-
-logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename)
-
import numpy as np
import matplotlib.pyplot as plt
@@ -79,8 +76,11 @@ class Measure_Shoulders:
self.plot = c.MS_plot
def _plot(self, signal):
+ if self.c.plot_location is None:
+ return
+
dt = datetime.datetime.now().isoformat()
- fig_path = logging_path + "/" + dt + "_sync_subsample_aligned.svg"
+ fig_path = self.c.plot_location + "/" + dt + "_sync_subsample_aligned.png"
fft = calc_fft_db(signal, 100, self.c)
peak, idxs_peak = _calc_peak(fft, self.c)
diff --git a/dpd/src/Model.py b/dpd/src/Model.py
index 67feeb6..b2c303f 100644
--- a/dpd/src/Model.py
+++ b/dpd/src/Model.py
@@ -2,6 +2,12 @@
from src.Model_Poly import Poly
from src.Model_Lut import Lut
+def select_model_from_dpddata(dpddata):
+ if dpddata[0] == 'lut':
+ return Lut
+ elif dpddata[0] == 'poly':
+ return Poly
+
# The MIT License (MIT)
#
# Copyright (c) 2017 Andreas Steger
diff --git a/dpd/src/Model_AM.py b/dpd/src/Model_AM.py
index d7e880c..9800d83 100644
--- a/dpd/src/Model_AM.py
+++ b/dpd/src/Model_AM.py
@@ -8,9 +8,6 @@
import datetime
import os
import logging
-
-logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename)
-
import numpy as np
import matplotlib.pyplot as plt
@@ -55,12 +52,12 @@ class Model_AM:
self.plot = plot
def _plot(self, tx_dpd, rx_received, coefs_am, coefs_am_new):
- if logging.getLogger().getEffectiveLevel() == logging.DEBUG and self.plot:
+ if self.plot and self.c.plot_location is not None:
tx_range, rx_est = calc_line(coefs_am, 0, 0.6)
tx_range_new, rx_est_new = calc_line(coefs_am_new, 0, 0.6)
dt = datetime.datetime.now().isoformat()
- fig_path = logging_path + "/" + dt + "_Model_AM.svg"
+ fig_path = self.c.plot_location + "/" + dt + "_Model_AM.png"
sub_rows = 1
sub_cols = 1
fig = plt.figure(figsize=(sub_cols * 6, sub_rows / 2. * 6))
diff --git a/dpd/src/Model_Lut.py b/dpd/src/Model_Lut.py
index 6d4db52..e70fdb0 100644
--- a/dpd/src/Model_Lut.py
+++ b/dpd/src/Model_Lut.py
@@ -7,12 +7,8 @@
import os
import logging
-
-logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename)
-
import numpy as np
-
class Lut:
"""Implements a model that calculates lookup table coefficients"""
diff --git a/dpd/src/Model_PM.py b/dpd/src/Model_PM.py
index d4f8c00..3aafea0 100644
--- a/dpd/src/Model_PM.py
+++ b/dpd/src/Model_PM.py
@@ -8,9 +8,6 @@
import datetime
import os
import logging
-
-logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename)
-
import numpy as np
import matplotlib.pyplot as plt
@@ -41,12 +38,12 @@ class Model_PM:
self.plot = plot
def _plot(self, tx_dpd, phase_diff, coefs_pm, coefs_pm_new):
- if logging.getLogger().getEffectiveLevel() == logging.DEBUG and self.plot:
+ if self.plot and self.c.plot_location is not None:
tx_range, phase_diff_est = self.calc_line(coefs_pm, 0, 0.6)
tx_range_new, phase_diff_est_new = self.calc_line(coefs_pm_new, 0, 0.6)
dt = datetime.datetime.now().isoformat()
- fig_path = logging_path + "/" + dt + "_Model_PM.svg"
+ fig_path = self.c.plot_location + "/" + dt + "_Model_PM.png"
sub_rows = 1
sub_cols = 1
fig = plt.figure(figsize=(sub_cols * 6, sub_rows / 2. * 6))
diff --git a/dpd/src/Model_Poly.py b/dpd/src/Model_Poly.py
index ff15941..cdfd319 100644
--- a/dpd/src/Model_Poly.py
+++ b/dpd/src/Model_Poly.py
@@ -7,9 +7,6 @@
import os
import logging
-
-logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename)
-
import numpy as np
import src.Model_AM as Model_AM
@@ -54,13 +51,10 @@ class Poly:
self.model_am = Model_AM.Model_AM(c, plot=self.plot)
self.model_pm = Model_PM.Model_PM(c, plot=self.plot)
- self.plot = c.MDL_plot
-
def reset_coefs(self):
self.coefs_am = np.zeros(5, dtype=np.float32)
self.coefs_am[0] = 1
self.coefs_pm = np.zeros(5, dtype=np.float32)
- return self.coefs_am, self.coefs_pm
def train(self, tx_abs, rx_abs, phase_diff, lr=None):
"""
diff --git a/dpd/src/RX_Agc.py b/dpd/src/RX_Agc.py
index 670fbbb..f778dee 100644
--- a/dpd/src/RX_Agc.py
+++ b/dpd/src/RX_Agc.py
@@ -9,8 +9,6 @@ import datetime
import os
import logging
import time
-logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename)
-
import numpy as np
import matplotlib
matplotlib.use('agg')
@@ -70,11 +68,14 @@ class Agc:
def plot_estimates(self):
"""Plots the estimate of for Max, Median, Mean for different
number of samples."""
+ if self.c.plot_location is None:
+ return
+
self.adapt.set_rxgain(self.min_rxgain)
time.sleep(1)
dt = datetime.datetime.now().isoformat()
- fig_path = logging_path + "/" + dt + "_agc.svg"
+ fig_path = self.c.plot_location + "/" + dt + "_agc.png"
fig, axs = plt.subplots(2, 2, figsize=(3*6,1*6))
axs = axs.ravel()
diff --git a/dpd/src/Symbol_align.py b/dpd/src/Symbol_align.py
index d921f25..2a17a65 100644
--- a/dpd/src/Symbol_align.py
+++ b/dpd/src/Symbol_align.py
@@ -8,12 +8,6 @@
import datetime
import os
import logging
-
-try:
- logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename)
-except:
- logging_path = "/tmp/"
-
import numpy as np
import scipy
import matplotlib
@@ -75,9 +69,9 @@ class Symbol_align:
offset = peaks[np.argmin([tx_product_avg[peak] for peak in peaks])]
- if logging.getLogger().getEffectiveLevel() == logging.DEBUG and self.plot:
+ if self.plot and self.c.plot_location is not None:
dt = datetime.datetime.now().isoformat()
- fig_path = logging_path + "/" + dt + "_Symbol_align.svg"
+ fig_path = self.c.plot_location + "/" + dt + "_Symbol_align.png"
fig = plt.figure(figsize=(9, 9))
diff --git a/dpd/src/TX_Agc.py b/dpd/src/TX_Agc.py
index 3c804fa..309193d 100644
--- a/dpd/src/TX_Agc.py
+++ b/dpd/src/TX_Agc.py
@@ -9,9 +9,6 @@ import datetime
import os
import logging
import time
-
-logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename)
-
import numpy as np
import matplotlib
diff --git a/dpd/src/phase_align.py b/dpd/src/phase_align.py
index 68c216d..8654333 100644
--- a/dpd/src/phase_align.py
+++ b/dpd/src/phase_align.py
@@ -7,8 +7,6 @@
import datetime
import os
import logging
-logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename)
-
import numpy as np
import matplotlib.pyplot as plt
@@ -24,9 +22,9 @@ def phase_align(sig, ref_sig, plot=False):
real_diffs = np.cos(angle_diff)
imag_diffs = np.sin(angle_diff)
- if logging.getLogger().getEffectiveLevel() == logging.DEBUG and plot:
+ if plot and self.c.plot_location is not None:
dt = datetime.datetime.now().isoformat()
- fig_path = logging_path + "/" + dt + "_phase_align.svg"
+ fig_path = self.c.plot_location + "/" + dt + "_phase_align.png"
plt.subplot(511)
plt.hist(angle_diff, bins=60, label="Angle Diff")
diff --git a/dpd/src/subsample_align.py b/dpd/src/subsample_align.py
index 68f3591..20ae56b 100755
--- a/dpd/src/subsample_align.py
+++ b/dpd/src/subsample_align.py
@@ -7,14 +7,10 @@
import datetime
import logging
import os
-
-logging_path = os.path.dirname(logging.getLoggerClass().root.handlers[0].baseFilename)
-
import numpy as np
from scipy import optimize
import matplotlib.pyplot as plt
-
def gen_omega(length):
if (length % 2) == 1:
raise ValueError("Needs an even length array.")
@@ -32,7 +28,7 @@ def gen_omega(length):
return omega
-def subsample_align(sig, ref_sig, plot=False):
+def subsample_align(sig, ref_sig, plot_location=None):
"""Do subsample alignment for sig relative to the reference signal
ref_sig. The delay between the two must be less than sample
@@ -72,13 +68,13 @@ def subsample_align(sig, ref_sig, plot=False):
if optim_result.success:
best_tau = optim_result.x
- if plot:
+ if plot_location is not None:
corr = np.vectorize(correlate_for_delay)
ixs = np.linspace(-1, 1, 100)
taus = corr(ixs)
dt = datetime.datetime.now().isoformat()
- tau_path = (logging_path + "/" + dt + "_tau.svg")
+ tau_path = (plot_location + "/" + dt + "_tau.png")
plt.plot(ixs, taus)
plt.title("Subsample correlation, minimum is best: {}".format(best_tau))
plt.savefig(tau_path)