aboutsummaryrefslogtreecommitdiffstats
path: root/dpd/src/Model_Poly.py
diff options
context:
space:
mode:
Diffstat (limited to 'dpd/src/Model_Poly.py')
-rw-r--r--dpd/src/Model_Poly.py21
1 files changed, 8 insertions, 13 deletions
diff --git a/dpd/src/Model_Poly.py b/dpd/src/Model_Poly.py
index 1faff24..f6c024c 100644
--- a/dpd/src/Model_Poly.py
+++ b/dpd/src/Model_Poly.py
@@ -37,40 +37,34 @@ def _check_input_get_next_coefs(tx_abs, rx_abs, phase_diff):
tx_abs.shape, phase_diff.shape)
-class Model_Poly:
+class Poly:
"""Calculates new coefficients using the measurement and the previous
coefficients"""
def __init__(self,
c,
- coefs_am,
- coefs_pm,
learning_rate_am=1.0,
learning_rate_pm=1.0,
plot=False):
- assert_np_float32(coefs_am)
- assert_np_float32(coefs_pm)
-
self.c = c
self.learning_rate_am = learning_rate_am
self.learning_rate_pm = learning_rate_pm
- self.coefs_am = coefs_am
- self.coefs_pm = coefs_pm
+ self.reset_coefs()
self.model_am = Model_AM.Model_AM(c, plot=True)
self.model_pm = Model_PM.Model_PM(c, plot=True)
self.plot = plot
- def get_default_coefs(self):
- self.coefs_am[:] = 0
+ def reset_coefs(self):
+ self.coefs_am = np.zeros(5, dtype=np.float32)
self.coefs_am[0] = 1
- self.coefs_pm[:] = 0
+ self.coefs_pm = np.zeros(5, dtype=np.float32)
return self.coefs_am, self.coefs_pm
- def get_next_coefs(self, tx_abs, rx_abs, phase_diff):
+ def train(self, tx_abs, rx_abs, phase_diff):
_check_input_get_next_coefs(tx_abs, rx_abs, phase_diff)
coefs_am_new = self.model_am.get_next_coefs(tx_abs, rx_abs, self.coefs_am)
@@ -79,7 +73,8 @@ class Model_Poly:
self.coefs_am = self.coefs_am + (coefs_am_new - self.coefs_am) * self.learning_rate_am
self.coefs_pm = self.coefs_pm + (coefs_pm_new - self.coefs_pm) * self.learning_rate_pm
- return self.coefs_am, self.coefs_pm
+ def get_dpd_data(self):
+ return "poly", self.coefs_am, self.coefs_pm
# The MIT License (MIT)
#