aboutsummaryrefslogtreecommitdiffstats
path: root/dpd
diff options
context:
space:
mode:
authorandreas128 <Andreas>2017-09-27 20:59:17 +0200
committerandreas128 <Andreas>2017-09-27 20:59:17 +0200
commit9954be2a2d94c6d7ebed8b36364d02e97084b9f2 (patch)
tree2173908a9b3103eaff8d42812accd167312232d8 /dpd
parent071088f747f5629f60b01017bdcff5161efb7ba5 (diff)
downloaddabmod-9954be2a2d94c6d7ebed8b36364d02e97084b9f2.tar.gz
dabmod-9954be2a2d94c6d7ebed8b36364d02e97084b9f2.tar.bz2
dabmod-9954be2a2d94c6d7ebed8b36364d02e97084b9f2.zip
Change fixed learning rate and number of measurements to heuristic
Diffstat (limited to 'dpd')
-rwxr-xr-xdpd/main.py14
-rw-r--r--dpd/src/Const.py2
-rw-r--r--dpd/src/Heuristics.py28
-rw-r--r--dpd/src/Model_AM.py15
-rw-r--r--dpd/src/Model_PM.py4
-rw-r--r--dpd/src/Model_Poly.py7
6 files changed, 57 insertions, 13 deletions
diff --git a/dpd/main.py b/dpd/main.py
index d71fd2d..d2b7297 100755
--- a/dpd/main.py
+++ b/dpd/main.py
@@ -13,6 +13,7 @@ predistortion module of ODR-DabMod."""
import datetime
import os
import matplotlib
+
matplotlib.use('GTKAgg')
import logging
@@ -49,6 +50,7 @@ from src.Const import Const
from src.MER import MER
from src.Measure_Shoulders import Measure_Shoulders
import argparse
+import src.Heuristics as Heur
parser = argparse.ArgumentParser(
description="DPD Computation Engine for ODR-DabMod")
@@ -189,7 +191,8 @@ while i < num_iter:
# Extract usable data from measurement
tx, rx, phase_diff, n_per_bin = extStat.extract(txframe_aligned, rxframe_aligned)
- if extStat.n_meas >= c.n_meas:
+ n_meas = Heur.get_n_meas(i)
+ if extStat.n_meas >= n_meas: # Use as many measurements nr of runs
state = 'model'
else:
state = 'measure'
@@ -197,7 +200,8 @@ while i < num_iter:
# Model
elif state == 'model':
# Calculate new model parameters and delete old measurements
- model.train(tx, rx, phase_diff)
+ lr = Heur.get_learning_rate(i)
+ model.train(tx, rx, phase_diff, lr=lr)
dpddata = model.get_dpd_data()
extStat = ExtractStatistic.ExtractStatistic(c)
state = 'adapt'
@@ -215,9 +219,9 @@ while i < num_iter:
# Collect logging data
off = SA.calc_offset(txframe_aligned)
- tx_mer = MER.calc_mer(txframe_aligned[off:off+c.T_U], debug_name='TX')
- rx_mer = MER.calc_mer(rxframe_aligned[off:off+c.T_U], debug_name='RX')
- mse = np.mean(np.abs((txframe_aligned - rxframe_aligned)**2))
+ tx_mer = MER.calc_mer(txframe_aligned[off:off + c.T_U], debug_name='TX')
+ rx_mer = MER.calc_mer(rxframe_aligned[off:off + c.T_U], debug_name='RX')
+ mse = np.mean(np.abs((txframe_aligned - rxframe_aligned) ** 2))
tx_gain = adapt.get_txgain()
rx_gain = adapt.get_rxgain()
digital_gain = adapt.get_digital_gain()
diff --git a/dpd/src/Const.py b/dpd/src/Const.py
index 2f9e151..bf46796 100644
--- a/dpd/src/Const.py
+++ b/dpd/src/Const.py
@@ -58,7 +58,7 @@ class Const:
self.RAGC_rx_median_target = self.TAGC_tx_median_target
# Constants for Model
- self.MDL_plot = False
+ self.MDL_plot = True
# Constants for MER
self.MER_plot = False
diff --git a/dpd/src/Heuristics.py b/dpd/src/Heuristics.py
new file mode 100644
index 0000000..b6ec37f
--- /dev/null
+++ b/dpd/src/Heuristics.py
@@ -0,0 +1,28 @@
+# -*- coding: utf-8 -*-
+#
+# DPD Calculation Engine, heuristics we use to tune the parameters
+#
+# http://www.opendigitalradio.org
+# Licence: The MIT License, see notice at the end of this file
+
+import numpy as np
+
+def get_learning_rate(idx_run):
+ idx_max = 10.0
+ lr_min = 0.05
+ lr_max = 1
+ lr_delta = lr_max - lr_min
+ idx_run = min(idx_run, idx_max)
+ learning_rate = lr_max - lr_delta * idx_run/idx_max
+ return learning_rate
+
+def get_n_meas(idx_run):
+ idx_max = 10.0
+ n_meas_min = 5
+ n_meas_max = 50
+ n_meas_delta = n_meas_max - n_meas_min
+ idx_run = min(idx_run, idx_max)
+ learning_rate = n_meas_delta * idx_run/idx_max + n_meas_min
+ return int(np.round(learning_rate))
+
+
diff --git a/dpd/src/Model_AM.py b/dpd/src/Model_AM.py
index 4b88e08..c6f7903 100644
--- a/dpd/src/Model_AM.py
+++ b/dpd/src/Model_AM.py
@@ -17,12 +17,12 @@ import matplotlib.pyplot as plt
def check_input_get_next_coefs(tx_dpd, rx_received):
is_float32 = lambda x: (isinstance(x, np.ndarray) and
- x.dtype == np.float32 and
- x.flags.contiguous)
+ x.dtype == np.float32 and
+ x.flags.contiguous)
assert is_float32(tx_dpd), \
- "tx_dpd is not float32 but {}".format(tx_dpd[0].dtype)
+ "tx_dpd is not float32 but {}".format(tx_dpd[0].dtype)
assert is_float32(rx_received), \
- "rx_received is not float32 but {}".format(tx_dpd[0].dtype)
+ "rx_received is not float32 but {}".format(tx_dpd[0].dtype)
class Model_AM:
@@ -31,7 +31,7 @@ class Model_AM:
def __init__(self,
c,
- learning_rate_am=0.1,
+ learning_rate_am=1,
plot=False):
self.c = c
@@ -66,6 +66,8 @@ class Model_AM:
ax.set_title("Model_AM")
ax.set_xlabel("TX Amplitude")
ax.set_ylabel("RX Amplitude")
+ xlim = ax.get_xlim()
+ ax.set_xlim(max(xlim[0], -1), min(xlim[1], 2))
ax.legend(loc=4)
fig.tight_layout()
@@ -87,6 +89,9 @@ class Model_AM:
check_input_get_next_coefs(tx_dpd, rx_received)
coefs_am_new = self.fit_poly(tx_dpd, rx_received)
+ coefs_am_new = coefs_am + \
+ self.learning_rate_am * (coefs_am_new - coefs_am)
+
self._plot(tx_dpd, rx_received, coefs_am, coefs_am_new)
return coefs_am_new
diff --git a/dpd/src/Model_PM.py b/dpd/src/Model_PM.py
index 75fb055..e0fcb55 100644
--- a/dpd/src/Model_PM.py
+++ b/dpd/src/Model_PM.py
@@ -34,7 +34,7 @@ class Model_PM:
def __init__(self,
c,
- learning_rate_pm=0.1,
+ learning_rate_pm=1,
plot=False):
self.c = c
@@ -97,6 +97,8 @@ class Model_PM:
check_input_get_next_coefs(tx_dpd, phase_diff)
coefs_pm_new = self.fit_poly(tx_dpd, phase_diff)
+
+ coefs_pm_new = coefs_pm + self.learning_rate_pm * (coefs_pm_new - coefs_pm)
self._plot(tx_dpd, phase_diff, coefs_pm, coefs_pm_new)
return coefs_pm_new
diff --git a/dpd/src/Model_Poly.py b/dpd/src/Model_Poly.py
index 44e0483..16881ad 100644
--- a/dpd/src/Model_Poly.py
+++ b/dpd/src/Model_Poly.py
@@ -62,14 +62,19 @@ class Poly:
self.coefs_pm = np.zeros(5, dtype=np.float32)
return self.coefs_am, self.coefs_pm
- def train(self, tx_abs, rx_abs, phase_diff):
+ def train(self, tx_abs, rx_abs, phase_diff, lr=None):
"""
:type tx_abs: np.ndarray
:type rx_abs: np.ndarray
:type phase_diff: np.ndarray
+ :type lr: float
"""
_check_input_get_next_coefs(tx_abs, rx_abs, phase_diff)
+ if not lr is None:
+ self.model_am.learning_rate_am = lr
+ self.model_pm.learning_rate_pm = lr
+
coefs_am_new = self.model_am.get_next_coefs(tx_abs, rx_abs, self.coefs_am)
coefs_pm_new = self.model_pm.get_next_coefs(tx_abs, phase_diff, self.coefs_pm)