aboutsummaryrefslogtreecommitdiffstats
path: root/dpd/src/Model_AM.py
diff options
context:
space:
mode:
Diffstat (limited to 'dpd/src/Model_AM.py')
-rw-r--r--dpd/src/Model_AM.py15
1 files changed, 10 insertions, 5 deletions
diff --git a/dpd/src/Model_AM.py b/dpd/src/Model_AM.py
index 4b88e08..c6f7903 100644
--- a/dpd/src/Model_AM.py
+++ b/dpd/src/Model_AM.py
@@ -17,12 +17,12 @@ import matplotlib.pyplot as plt
def check_input_get_next_coefs(tx_dpd, rx_received):
is_float32 = lambda x: (isinstance(x, np.ndarray) and
- x.dtype == np.float32 and
- x.flags.contiguous)
+ x.dtype == np.float32 and
+ x.flags.contiguous)
assert is_float32(tx_dpd), \
- "tx_dpd is not float32 but {}".format(tx_dpd[0].dtype)
+ "tx_dpd is not float32 but {}".format(tx_dpd[0].dtype)
assert is_float32(rx_received), \
- "rx_received is not float32 but {}".format(tx_dpd[0].dtype)
+ "rx_received is not float32 but {}".format(tx_dpd[0].dtype)
class Model_AM:
@@ -31,7 +31,7 @@ class Model_AM:
def __init__(self,
c,
- learning_rate_am=0.1,
+ learning_rate_am=1,
plot=False):
self.c = c
@@ -66,6 +66,8 @@ class Model_AM:
ax.set_title("Model_AM")
ax.set_xlabel("TX Amplitude")
ax.set_ylabel("RX Amplitude")
+ xlim = ax.get_xlim()
+ ax.set_xlim(max(xlim[0], -1), min(xlim[1], 2))
ax.legend(loc=4)
fig.tight_layout()
@@ -87,6 +89,9 @@ class Model_AM:
check_input_get_next_coefs(tx_dpd, rx_received)
coefs_am_new = self.fit_poly(tx_dpd, rx_received)
+ coefs_am_new = coefs_am + \
+ self.learning_rate_am * (coefs_am_new - coefs_am)
+
self._plot(tx_dpd, rx_received, coefs_am, coefs_am_new)
return coefs_am_new