From 9954be2a2d94c6d7ebed8b36364d02e97084b9f2 Mon Sep 17 00:00:00 2001 From: andreas128 Date: Wed, 27 Sep 2017 20:59:17 +0200 Subject: Change fixed learning rate and number of measurements to heuristic --- dpd/main.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) (limited to 'dpd/main.py') diff --git a/dpd/main.py b/dpd/main.py index d71fd2d..d2b7297 100755 --- a/dpd/main.py +++ b/dpd/main.py @@ -13,6 +13,7 @@ predistortion module of ODR-DabMod.""" import datetime import os import matplotlib + matplotlib.use('GTKAgg') import logging @@ -49,6 +50,7 @@ from src.Const import Const from src.MER import MER from src.Measure_Shoulders import Measure_Shoulders import argparse +import src.Heuristics as Heur parser = argparse.ArgumentParser( description="DPD Computation Engine for ODR-DabMod") @@ -189,7 +191,8 @@ while i < num_iter: # Extract usable data from measurement tx, rx, phase_diff, n_per_bin = extStat.extract(txframe_aligned, rxframe_aligned) - if extStat.n_meas >= c.n_meas: + n_meas = Heur.get_n_meas(i) + if extStat.n_meas >= n_meas: # Use as many measurements nr of runs state = 'model' else: state = 'measure' @@ -197,7 +200,8 @@ while i < num_iter: # Model elif state == 'model': # Calculate new model parameters and delete old measurements - model.train(tx, rx, phase_diff) + lr = Heur.get_learning_rate(i) + model.train(tx, rx, phase_diff, lr=lr) dpddata = model.get_dpd_data() extStat = ExtractStatistic.ExtractStatistic(c) state = 'adapt' @@ -215,9 +219,9 @@ while i < num_iter: # Collect logging data off = SA.calc_offset(txframe_aligned) - tx_mer = MER.calc_mer(txframe_aligned[off:off+c.T_U], debug_name='TX') - rx_mer = MER.calc_mer(rxframe_aligned[off:off+c.T_U], debug_name='RX') - mse = np.mean(np.abs((txframe_aligned - rxframe_aligned)**2)) + tx_mer = MER.calc_mer(txframe_aligned[off:off + c.T_U], debug_name='TX') + rx_mer = MER.calc_mer(rxframe_aligned[off:off + c.T_U], debug_name='RX') + mse = np.mean(np.abs((txframe_aligned - rxframe_aligned) ** 2)) tx_gain = adapt.get_txgain() rx_gain = adapt.get_rxgain() digital_gain = adapt.get_digital_gain() -- cgit v1.2.3