diff options
Diffstat (limited to 'dpd/src/Model_Poly.py')
-rw-r--r-- | dpd/src/Model_Poly.py | 21 |
1 files changed, 8 insertions, 13 deletions
diff --git a/dpd/src/Model_Poly.py b/dpd/src/Model_Poly.py index 1faff24..f6c024c 100644 --- a/dpd/src/Model_Poly.py +++ b/dpd/src/Model_Poly.py @@ -37,40 +37,34 @@ def _check_input_get_next_coefs(tx_abs, rx_abs, phase_diff): tx_abs.shape, phase_diff.shape) -class Model_Poly: +class Poly: """Calculates new coefficients using the measurement and the previous coefficients""" def __init__(self, c, - coefs_am, - coefs_pm, learning_rate_am=1.0, learning_rate_pm=1.0, plot=False): - assert_np_float32(coefs_am) - assert_np_float32(coefs_pm) - self.c = c self.learning_rate_am = learning_rate_am self.learning_rate_pm = learning_rate_pm - self.coefs_am = coefs_am - self.coefs_pm = coefs_pm + self.reset_coefs() self.model_am = Model_AM.Model_AM(c, plot=True) self.model_pm = Model_PM.Model_PM(c, plot=True) self.plot = plot - def get_default_coefs(self): - self.coefs_am[:] = 0 + def reset_coefs(self): + self.coefs_am = np.zeros(5, dtype=np.float32) self.coefs_am[0] = 1 - self.coefs_pm[:] = 0 + self.coefs_pm = np.zeros(5, dtype=np.float32) return self.coefs_am, self.coefs_pm - def get_next_coefs(self, tx_abs, rx_abs, phase_diff): + def train(self, tx_abs, rx_abs, phase_diff): _check_input_get_next_coefs(tx_abs, rx_abs, phase_diff) coefs_am_new = self.model_am.get_next_coefs(tx_abs, rx_abs, self.coefs_am) @@ -79,7 +73,8 @@ class Model_Poly: self.coefs_am = self.coefs_am + (coefs_am_new - self.coefs_am) * self.learning_rate_am self.coefs_pm = self.coefs_pm + (coefs_pm_new - self.coefs_pm) * self.learning_rate_pm - return self.coefs_am, self.coefs_pm + def get_dpd_data(self): + return "poly", self.coefs_am, self.coefs_pm # The MIT License (MIT) # |