aboutsummaryrefslogtreecommitdiffstats
path: root/dpd/src/Model_AM.py
diff options
context:
space:
mode:
Diffstat (limited to 'dpd/src/Model_AM.py')
-rw-r--r--dpd/src/Model_AM.py48
1 files changed, 26 insertions, 22 deletions
diff --git a/dpd/src/Model_AM.py b/dpd/src/Model_AM.py
index 85f6495..cdc3de1 100644
--- a/dpd/src/Model_AM.py
+++ b/dpd/src/Model_AM.py
@@ -15,14 +15,30 @@ import numpy as np
import matplotlib.pyplot as plt
+def is_npfloat32(array):
+ assert isinstance(array, np.ndarray), type(array)
+ assert array.dtype == np.float32, array.dtype
+ assert array.flags.contiguous
+ assert not any(np.isnan(array))
+
+
def check_input_get_next_coefs(tx_dpd, rx_received):
- is_float32 = lambda x: (isinstance(x, np.ndarray) and
- x.dtype == np.float32 and
- x.flags.contiguous)
- assert is_float32(tx_dpd), \
- "tx_dpd is not float32 but {}".format(tx_dpd[0].dtype)
- assert is_float32(rx_received), \
- "rx_received is not float32 but {}".format(tx_dpd[0].dtype)
+ is_npfloat32(tx_dpd)
+ is_npfloat32(rx_received)
+
+
+def poly(sig):
+ return np.array([sig ** i for i in range(1, 6)]).T
+
+
+def fit_poly(tx_abs, rx_abs):
+ return np.linalg.lstsq(poly(rx_abs), tx_abs)[0]
+
+
+def calc_line(coefs, min_amp, max_amp):
+ rx_range = np.linspace(min_amp, max_amp)
+ tx_est = np.sum(poly(rx_range) * coefs, axis=1)
+ return tx_est, rx_range
class Model_AM:
@@ -40,8 +56,8 @@ class Model_AM:
def _plot(self, tx_dpd, rx_received, coefs_am, coefs_am_new):
if logging.getLogger().getEffectiveLevel() == logging.DEBUG and self.plot:
- tx_range, rx_est = self.calc_line(coefs_am, 0, 0.6)
- tx_range_new, rx_est_new = self.calc_line(coefs_am_new, 0, 0.6)
+ tx_range, rx_est = calc_line(coefs_am, 0, 0.6)
+ tx_range_new, rx_est_new = calc_line(coefs_am_new, 0, 0.6)
dt = datetime.datetime.now().isoformat()
fig_path = logging_path + "/" + dt + "_Model_AM.svg"
@@ -66,7 +82,6 @@ class Model_AM:
ax.set_title("Model_AM")
ax.set_xlabel("TX Amplitude")
ax.set_ylabel("RX Amplitude")
- xlim = ax.get_xlim()
ax.set_xlim(-0.5, 1.5)
ax.legend(loc=4)
@@ -74,21 +89,10 @@ class Model_AM:
fig.savefig(fig_path)
plt.close(fig)
- def poly(self, sig):
- return np.array([sig ** i for i in range(1, 6)]).T
-
- def fit_poly(self, tx_abs, rx_abs):
- return np.linalg.lstsq(self.poly(rx_abs), tx_abs)[0]
-
- def calc_line(self, coefs, min_amp, max_amp):
- rx_range = np.linspace(min_amp, max_amp)
- tx_est = np.sum(self.poly(rx_range) * coefs, axis=1)
- return tx_est, rx_range
-
def get_next_coefs(self, tx_dpd, rx_received, coefs_am):
check_input_get_next_coefs(tx_dpd, rx_received)
- coefs_am_new = self.fit_poly(tx_dpd, rx_received)
+ coefs_am_new = fit_poly(tx_dpd, rx_received)
coefs_am_new = coefs_am + \
self.learning_rate_am * (coefs_am_new - coefs_am)