aboutsummaryrefslogtreecommitdiffstats
path: root/python/dpd/Model_Poly.py
diff options
context:
space:
mode:
authorMatthias P. Braendli <matthias.braendli@mpb.li>2018-12-19 16:11:58 +0100
committerMatthias P. Braendli <matthias.braendli@mpb.li>2018-12-19 16:12:19 +0100
commitf4ca82137e850e30d31e7008b34800d8b2699e5d (patch)
treeff19ad63f6ddf8a4f62b173c5955b2711646f123 /python/dpd/Model_Poly.py
parent9d2c85f7a2a23fcf9ce3c842d86227afed43a153 (diff)
downloaddabmod-f4ca82137e850e30d31e7008b34800d8b2699e5d.tar.gz
dabmod-f4ca82137e850e30d31e7008b34800d8b2699e5d.tar.bz2
dabmod-f4ca82137e850e30d31e7008b34800d8b2699e5d.zip
DPD: Merge Model_PM and _AM into _Poly
Diffstat (limited to 'python/dpd/Model_Poly.py')
-rw-r--r--python/dpd/Model_Poly.py146
1 files changed, 127 insertions, 19 deletions
diff --git a/python/dpd/Model_Poly.py b/python/dpd/Model_Poly.py
index ca39492..5722531 100644
--- a/python/dpd/Model_Poly.py
+++ b/python/dpd/Model_Poly.py
@@ -8,15 +8,13 @@
import os
import logging
import numpy as np
+import matplotlib.pyplot as plt
-import dpd.Model_AM as Model_AM
-import dpd.Model_PM as Model_PM
-
-
-def assert_np_float32(x):
- assert isinstance(x, np.ndarray)
- assert x.dtype == np.float32
- assert x.flags.contiguous
+def assert_np_float32(array):
+ assert isinstance(array, np.ndarray), type(array)
+ assert array.dtype == np.float32, array.dtype
+ assert array.flags.contiguous
+ assert not any(np.isnan(array))
def _check_input_get_next_coefs(tx_abs, rx_abs, phase_diff):
@@ -44,12 +42,73 @@ class Poly:
self.reset_coefs()
- self.model_am = Model_AM.Model_AM(c)
- self.model_pm = Model_PM.Model_PM(c)
-
def plot(self, am_plot_location, pm_plot_location, title):
- self.model_am.plot(am_plot_location, title)
- self.model_pm.plot(pm_plot_location, title)
+ if self._am_plot_data is not None:
+ tx_dpd, rx_received, coefs_am, coefs_am_new = self._am_plot_data
+
+ tx_range, rx_est = self._am_calc_line(coefs_am, 0, 0.6)
+ tx_range_new, rx_est_new = self._am_calc_line(coefs_am_new, 0, 0.6)
+
+ sub_rows = 1
+ sub_cols = 1
+ fig = plt.figure(figsize=(sub_cols * 6, sub_rows / 2. * 6))
+ i_sub = 0
+
+ i_sub += 1
+ ax = plt.subplot(sub_rows, sub_cols, i_sub)
+ ax.plot(tx_range, rx_est,
+ label="Estimated TX",
+ alpha=0.3,
+ color="gray")
+ ax.plot(tx_range_new, rx_est_new,
+ label="New Estimated TX",
+ color="red")
+ ax.scatter(tx_dpd, rx_received,
+ label="Binned Data",
+ color="blue",
+ s=1)
+ ax.set_title("Model AM {}".format(title))
+ ax.set_xlabel("TX Amplitude")
+ ax.set_ylabel("RX Amplitude")
+ ax.set_xlim(-0.5, 1.5)
+ ax.legend(loc=4)
+
+ fig.tight_layout()
+ fig.savefig(am_plot_location)
+ plt.close(fig)
+
+ if self._pm_plot_data is not None:
+ tx_dpd, phase_diff, coefs_pm, coefs_pm_new = self._pm_plot_data
+
+ tx_range, phase_diff_est = self._pm_calc_line(coefs_pm, 0, 0.6)
+ tx_range_new, phase_diff_est_new = self._pm_calc_line(coefs_pm_new, 0, 0.6)
+
+ sub_rows = 1
+ sub_cols = 1
+ fig = plt.figure(figsize=(sub_cols * 6, sub_rows / 2. * 6))
+ i_sub = 0
+
+ i_sub += 1
+ ax = plt.subplot(sub_rows, sub_cols, i_sub)
+ ax.plot(tx_range, phase_diff_est,
+ label="Estimated Phase Diff",
+ alpha=0.3,
+ color="gray")
+ ax.plot(tx_range_new, phase_diff_est_new,
+ label="New Estimated Phase Diff",
+ color="red")
+ ax.scatter(tx_dpd, phase_diff,
+ label="Binned Data",
+ color="blue",
+ s=1)
+ ax.set_title("Model PM {}".format(title))
+ ax.set_xlabel("TX Amplitude")
+ ax.set_ylabel("Phase DIff")
+ ax.legend(loc=4)
+
+ fig.tight_layout()
+ fig.savefig(pm_plot_location)
+ plt.close(fig)
def reset_coefs(self):
self.coefs_am = np.zeros(5, dtype=np.float32)
@@ -65,12 +124,8 @@ class Poly:
"""
_check_input_get_next_coefs(tx_abs, rx_abs, phase_diff)
- if not lr is None:
- self.model_am.learning_rate_am = lr
- self.model_pm.learning_rate_pm = lr
-
- coefs_am_new = self.model_am.get_next_coefs(tx_abs, rx_abs, self.coefs_am)
- coefs_pm_new = self.model_pm.get_next_coefs(tx_abs, phase_diff, self.coefs_pm)
+ coefs_am_new = self._am_get_next_coefs(tx_abs, rx_abs, self.coefs_am)
+ coefs_pm_new = self._pm_get_next_coefs(tx_abs, phase_diff, self.coefs_pm)
self.coefs_am = self.coefs_am + (coefs_am_new - self.coefs_am) * self.learning_rate_am
self.coefs_pm = self.coefs_pm + (coefs_pm_new - self.coefs_pm) * self.learning_rate_pm
@@ -78,9 +133,62 @@ class Poly:
def get_dpd_data(self):
return "poly", self.coefs_am, self.coefs_pm
+ def _am_calc_line(self, coefs, min_amp, max_amp):
+ rx_range = np.linspace(min_amp, max_amp)
+ tx_est = np.sum(self._am_poly(rx_range) * coefs, axis=1)
+ return tx_est, rx_range
+
+ def _am_poly(self, sig):
+ return np.array([sig ** i for i in range(1, 6)]).T
+
+ def _am_fit_poly(self, tx_abs, rx_abs):
+ return np.linalg.lstsq(self._am_poly(rx_abs), tx_abs, rcond=None)[0]
+
+ def _am_get_next_coefs(self, tx_dpd, rx_received, coefs_am):
+ """Calculate the next AM/AM coefficients using the extracted
+ statistic of TX and RX amplitude"""
+
+ coefs_am_new = self._am_fit_poly(tx_dpd, rx_received)
+ coefs_am_new = coefs_am + \
+ self.learning_rate_am * (coefs_am_new - coefs_am)
+
+ self._am_plot_data = (tx_dpd, rx_received, coefs_am, coefs_am_new)
+
+ return coefs_am_new
+
+ def _pm_poly(self, sig):
+ return np.array([sig ** i for i in range(0, 5)]).T
+
+ def _pm_calc_line(self, coefs, min_amp, max_amp):
+ tx_range = np.linspace(min_amp, max_amp)
+ phase_diff = np.sum(self._pm_poly(tx_range) * coefs, axis=1)
+ return tx_range, phase_diff
+
+ def _discard_small_values(self, tx_dpd, phase_diff):
+ """ Assumes that the phase for small tx amplitudes is zero"""
+ mask = tx_dpd < self.c.MPM_tx_min
+ phase_diff[mask] = 0
+ return tx_dpd, phase_diff
+
+ def _pm_fit_poly(self, tx_abs, phase_diff):
+ return np.linalg.lstsq(self._pm_poly(tx_abs), phase_diff, rcond=None)[0]
+
+ def _pm_get_next_coefs(self, tx_dpd, phase_diff, coefs_pm):
+ """Calculate the next AM/PM coefficients using the extracted
+ statistic of TX amplitude and phase difference"""
+ tx_dpd, phase_diff = self._discard_small_values(tx_dpd, phase_diff)
+
+ coefs_pm_new = self._pm_fit_poly(tx_dpd, phase_diff)
+
+ coefs_pm_new = coefs_pm + self.learning_rate_pm * (coefs_pm_new - coefs_pm)
+ self._pm_plot_data = (tx_dpd, phase_diff, coefs_pm, coefs_pm_new)
+
+ return coefs_pm_new
+
# The MIT License (MIT)
#
# Copyright (c) 2017 Andreas Steger
+# Copyright (c) 2018 Matthias P. Brandli
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal