aboutsummaryrefslogtreecommitdiffstats
path: root/dpd/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'dpd/main.py')
-rwxr-xr-xdpd/main.py100
1 files changed, 70 insertions, 30 deletions
diff --git a/dpd/main.py b/dpd/main.py
index de3453e..084ccd5 100755
--- a/dpd/main.py
+++ b/dpd/main.py
@@ -42,6 +42,7 @@ import numpy as np
import traceback
import src.Measure as Measure
import src.Model as Model
+import src.ExtractStatistic as ExtractStatistic
import src.Adapt as Adapt
import src.Agc as Agc
import src.TX_Agc as TX_Agc
@@ -52,19 +53,19 @@ import argparse
parser = argparse.ArgumentParser(
description="DPD Computation Engine for ODR-DabMod")
-parser.add_argument('--port', default='50055',
+parser.add_argument('--port', default=50055, type=int,
help='port of DPD server to connect to (default: 50055)',
required=False)
-parser.add_argument('--rc-port', default='9400',
+parser.add_argument('--rc-port', default=9400, type=int,
help='port of ODR-DabMod ZMQ Remote Control to connect to (default: 9400)',
required=False)
-parser.add_argument('--samplerate', default='8192000',
+parser.add_argument('--samplerate', default=8192000, type=int,
help='Sample rate',
required=False)
parser.add_argument('--coefs', default='poly.coef',
help='File with DPD coefficients, which will be read by ODR-DabMod',
required=False)
-parser.add_argument('--txgain', default=71,
+parser.add_argument('--txgain', default=73,
help='TX Gain',
required=False,
type=int)
@@ -76,10 +77,10 @@ parser.add_argument('--digital_gain', default=1,
help='Digital Gain',
required=False,
type=float)
-parser.add_argument('--samps', default='81920',
+parser.add_argument('--samps', default='81920', type=int,
help='Number of samples to request from ODR-DabMod',
required=False)
-parser.add_argument('-i', '--iterations', default='1',
+parser.add_argument('-i', '--iterations', default=1, type=int,
help='Number of iterations to run',
required=False)
parser.add_argument('-L', '--lut',
@@ -88,29 +89,29 @@ parser.add_argument('-L', '--lut',
cli_args = parser.parse_args()
-port = int(cli_args.port)
-port_rc = int(cli_args.rc_port)
+port = cli_args.port
+port_rc = cli_args.rc_port
coef_path = cli_args.coefs
digital_gain = cli_args.digital_gain
txgain = cli_args.txgain
rxgain = cli_args.rxgain
-num_req = int(cli_args.samps)
-samplerate = int(cli_args.samplerate)
-num_iter = int(cli_args.iterations)
+num_req = cli_args.samps
+samplerate = cli_args.samplerate
+num_iter = cli_args.iterations
SA = src.Symbol_align.Symbol_align(samplerate)
MER = src.MER.MER(samplerate)
c = src.const.const(samplerate)
meas = Measure.Measure(samplerate, port, num_req)
-
+extStat = ExtractStatistic.ExtractStatistic(c, plot=True)
adapt = Adapt.Adapt(port_rc, coef_path)
dpddata = adapt.get_predistorter()
if cli_args.lut:
- model = Model.LutModel(c, SA, MER, plot=True)
+ model = Model.Lut(c, plot=True)
else:
- model = Model.PolyModel(c, SA, MER, None, None, plot=True)
+ model = Model.Poly(c, plot=True)
adapt.set_predistorter(model.get_dpd_data())
adapt.set_digital_gain(digital_gain)
adapt.set_txgain(txgain)
@@ -120,7 +121,7 @@ tx_gain = adapt.get_txgain()
rx_gain = adapt.get_rxgain()
digital_gain = adapt.get_digital_gain()
-dpddata = adapt.get_coefs()
+dpddata = adapt.get_predistorter()
if dpddata[0] == "poly":
coefs_am = dpddata[1]
coefs_pm = dpddata[2]
@@ -148,23 +149,62 @@ tx_agc = TX_Agc.TX_Agc(adapt)
agc = Agc.Agc(meas, adapt)
agc.run()
-for i in range(num_iter):
+state = "measure"
+i = 0
+while i < num_iter:
try:
- txframe_aligned, tx_ts, rxframe_aligned, rx_ts, rx_median = meas.get_samples()
- logging.debug("tx_ts {}, rx_ts {}".format(tx_ts, rx_ts))
- assert tx_ts - rx_ts < 1e-5, "Time stamps do not match."
-
- if tx_agc.adapt_if_necessary(txframe_aligned):
- continue
-
- model.train(txframe_aligned, rxframe_aligned)
- adapt.set_predistorter(model.get_dpd_data())
+ # Measure
+ if state == "measure":
+ txframe_aligned, tx_ts, rxframe_aligned, rx_ts, rx_median = meas.get_samples()
+ tx, rx, phase_diff, n_per_bin = extStat.extract(txframe_aligned, rxframe_aligned)
+ n_use = int(len(n_per_bin) * 0.6)
+ tx = tx[:n_use]
+ rx = rx[:n_use]
+ phase_diff = phase_diff[:n_use]
+ if all(c.ES_n_per_bin == np.array(n_per_bin)[0:n_use]):
+ state = "model"
+ else:
+ state = "measure"
+
+ # Model
+ elif state == "model":
+ dpddata = model_poly.get_dpd_data(tx, rx, phase_diff)
+ del extStat
+ extStat = ExtractStatistic.ExtractStatistic(c, plot=True)
+ state = "adapt"
+
+ # Adapt
+ elif state == "adapt":
+ adapt.set_predistorter(dpddata)
+ state = "report"
+ i += 1
+
+ # Report
+ elif state == "report":
+ try:
+ off = SA.calc_offset(txframe_aligned)
+ tx_mer = MER.calc_mer(txframe_aligned[off:off+c.T_U], debug=True)
+ rx_mer = MER.calc_mer(rxframe_aligned[off:off+c.T_U], debug=True)
+ mse = np.mean(np.abs((txframe_aligned - rxframe_aligned)**2))
+
+ if dpddata[0] == "poly":
+ coefs_am = dpddata[1]
+ coefs_pm = dpddata[2]
+ logging.info("It {}: TX_MER {}, RX_MER {}," \
+ " MSE {}, coefs_am {}, coefs_pm {}".
+ format(i, tx_mer, rx_mer, mse, coefs_am, coefs_pm))
+ if dpddata[0] == "lut":
+ scalefactor = dpddata[1]
+ lut = dpddata[2]
+ logging.info("It {}: TX_MER {}, RX_MER {}," \
+ " MSE {}, LUT scalefactor {}, LUT {}".
+ format(i, tx_mer, rx_mer, mse, scalefactor, lut))
+ state = "measure"
+ except:
+ logging.warning("Iteration {}: Report failed.".format(i))
+ logging.warning(traceback.format_exc())
+ state = "measure"
- off = SA.calc_offset(txframe_aligned)
- tx_mer = MER.calc_mer(txframe_aligned[off:off + c.T_U])
- rx_mer = MER.calc_mer(rxframe_aligned[off:off + c.T_U], debug=True)
- logging.info("MER with lag in it. {}: TX {}, RX {}".
- format(i, tx_mer, rx_mer))
except Exception as e:
logging.warning("Iteration {} failed.".format(i))
logging.warning(traceback.format_exc())