From 9954be2a2d94c6d7ebed8b36364d02e97084b9f2 Mon Sep 17 00:00:00 2001 From: andreas128 Date: Wed, 27 Sep 2017 20:59:17 +0200 Subject: Change fixed learning rate and number of measurements to heuristic --- dpd/main.py | 14 +++++++++----- dpd/src/Const.py | 2 +- dpd/src/Heuristics.py | 28 ++++++++++++++++++++++++++++ dpd/src/Model_AM.py | 15 ++++++++++----- dpd/src/Model_PM.py | 4 +++- dpd/src/Model_Poly.py | 7 ++++++- 6 files changed, 57 insertions(+), 13 deletions(-) create mode 100644 dpd/src/Heuristics.py diff --git a/dpd/main.py b/dpd/main.py index d71fd2d..d2b7297 100755 --- a/dpd/main.py +++ b/dpd/main.py @@ -13,6 +13,7 @@ predistortion module of ODR-DabMod.""" import datetime import os import matplotlib + matplotlib.use('GTKAgg') import logging @@ -49,6 +50,7 @@ from src.Const import Const from src.MER import MER from src.Measure_Shoulders import Measure_Shoulders import argparse +import src.Heuristics as Heur parser = argparse.ArgumentParser( description="DPD Computation Engine for ODR-DabMod") @@ -189,7 +191,8 @@ while i < num_iter: # Extract usable data from measurement tx, rx, phase_diff, n_per_bin = extStat.extract(txframe_aligned, rxframe_aligned) - if extStat.n_meas >= c.n_meas: + n_meas = Heur.get_n_meas(i) + if extStat.n_meas >= n_meas: # Use as many measurements nr of runs state = 'model' else: state = 'measure' @@ -197,7 +200,8 @@ while i < num_iter: # Model elif state == 'model': # Calculate new model parameters and delete old measurements - model.train(tx, rx, phase_diff) + lr = Heur.get_learning_rate(i) + model.train(tx, rx, phase_diff, lr=lr) dpddata = model.get_dpd_data() extStat = ExtractStatistic.ExtractStatistic(c) state = 'adapt' @@ -215,9 +219,9 @@ while i < num_iter: # Collect logging data off = SA.calc_offset(txframe_aligned) - tx_mer = MER.calc_mer(txframe_aligned[off:off+c.T_U], debug_name='TX') - rx_mer = MER.calc_mer(rxframe_aligned[off:off+c.T_U], debug_name='RX') - mse = np.mean(np.abs((txframe_aligned - rxframe_aligned)**2)) + tx_mer = MER.calc_mer(txframe_aligned[off:off + c.T_U], debug_name='TX') + rx_mer = MER.calc_mer(rxframe_aligned[off:off + c.T_U], debug_name='RX') + mse = np.mean(np.abs((txframe_aligned - rxframe_aligned) ** 2)) tx_gain = adapt.get_txgain() rx_gain = adapt.get_rxgain() digital_gain = adapt.get_digital_gain() diff --git a/dpd/src/Const.py b/dpd/src/Const.py index 2f9e151..bf46796 100644 --- a/dpd/src/Const.py +++ b/dpd/src/Const.py @@ -58,7 +58,7 @@ class Const: self.RAGC_rx_median_target = self.TAGC_tx_median_target # Constants for Model - self.MDL_plot = False + self.MDL_plot = True # Constants for MER self.MER_plot = False diff --git a/dpd/src/Heuristics.py b/dpd/src/Heuristics.py new file mode 100644 index 0000000..b6ec37f --- /dev/null +++ b/dpd/src/Heuristics.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# +# DPD Calculation Engine, heuristics we use to tune the parameters +# +# http://www.opendigitalradio.org +# Licence: The MIT License, see notice at the end of this file + +import numpy as np + +def get_learning_rate(idx_run): + idx_max = 10.0 + lr_min = 0.05 + lr_max = 1 + lr_delta = lr_max - lr_min + idx_run = min(idx_run, idx_max) + learning_rate = lr_max - lr_delta * idx_run/idx_max + return learning_rate + +def get_n_meas(idx_run): + idx_max = 10.0 + n_meas_min = 5 + n_meas_max = 50 + n_meas_delta = n_meas_max - n_meas_min + idx_run = min(idx_run, idx_max) + learning_rate = n_meas_delta * idx_run/idx_max + n_meas_min + return int(np.round(learning_rate)) + + diff --git a/dpd/src/Model_AM.py b/dpd/src/Model_AM.py index 4b88e08..c6f7903 100644 --- a/dpd/src/Model_AM.py +++ b/dpd/src/Model_AM.py @@ -17,12 +17,12 @@ import matplotlib.pyplot as plt def check_input_get_next_coefs(tx_dpd, rx_received): is_float32 = lambda x: (isinstance(x, np.ndarray) and - x.dtype == np.float32 and - x.flags.contiguous) + x.dtype == np.float32 and + x.flags.contiguous) assert is_float32(tx_dpd), \ - "tx_dpd is not float32 but {}".format(tx_dpd[0].dtype) + "tx_dpd is not float32 but {}".format(tx_dpd[0].dtype) assert is_float32(rx_received), \ - "rx_received is not float32 but {}".format(tx_dpd[0].dtype) + "rx_received is not float32 but {}".format(tx_dpd[0].dtype) class Model_AM: @@ -31,7 +31,7 @@ class Model_AM: def __init__(self, c, - learning_rate_am=0.1, + learning_rate_am=1, plot=False): self.c = c @@ -66,6 +66,8 @@ class Model_AM: ax.set_title("Model_AM") ax.set_xlabel("TX Amplitude") ax.set_ylabel("RX Amplitude") + xlim = ax.get_xlim() + ax.set_xlim(max(xlim[0], -1), min(xlim[1], 2)) ax.legend(loc=4) fig.tight_layout() @@ -87,6 +89,9 @@ class Model_AM: check_input_get_next_coefs(tx_dpd, rx_received) coefs_am_new = self.fit_poly(tx_dpd, rx_received) + coefs_am_new = coefs_am + \ + self.learning_rate_am * (coefs_am_new - coefs_am) + self._plot(tx_dpd, rx_received, coefs_am, coefs_am_new) return coefs_am_new diff --git a/dpd/src/Model_PM.py b/dpd/src/Model_PM.py index 75fb055..e0fcb55 100644 --- a/dpd/src/Model_PM.py +++ b/dpd/src/Model_PM.py @@ -34,7 +34,7 @@ class Model_PM: def __init__(self, c, - learning_rate_pm=0.1, + learning_rate_pm=1, plot=False): self.c = c @@ -97,6 +97,8 @@ class Model_PM: check_input_get_next_coefs(tx_dpd, phase_diff) coefs_pm_new = self.fit_poly(tx_dpd, phase_diff) + + coefs_pm_new = coefs_pm + self.learning_rate_pm * (coefs_pm_new - coefs_pm) self._plot(tx_dpd, phase_diff, coefs_pm, coefs_pm_new) return coefs_pm_new diff --git a/dpd/src/Model_Poly.py b/dpd/src/Model_Poly.py index 44e0483..16881ad 100644 --- a/dpd/src/Model_Poly.py +++ b/dpd/src/Model_Poly.py @@ -62,14 +62,19 @@ class Poly: self.coefs_pm = np.zeros(5, dtype=np.float32) return self.coefs_am, self.coefs_pm - def train(self, tx_abs, rx_abs, phase_diff): + def train(self, tx_abs, rx_abs, phase_diff, lr=None): """ :type tx_abs: np.ndarray :type rx_abs: np.ndarray :type phase_diff: np.ndarray + :type lr: float """ _check_input_get_next_coefs(tx_abs, rx_abs, phase_diff) + if not lr is None: + self.model_am.learning_rate_am = lr + self.model_pm.learning_rate_pm = lr + coefs_am_new = self.model_am.get_next_coefs(tx_abs, rx_abs, self.coefs_am) coefs_pm_new = self.model_pm.get_next_coefs(tx_abs, phase_diff, self.coefs_pm) -- cgit v1.2.3