diff options
Diffstat (limited to 'dpd/src/Model_Poly.py')
-rw-r--r-- | dpd/src/Model_Poly.py | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/dpd/src/Model_Poly.py b/dpd/src/Model_Poly.py index 44e0483..16881ad 100644 --- a/dpd/src/Model_Poly.py +++ b/dpd/src/Model_Poly.py @@ -62,14 +62,19 @@ class Poly: self.coefs_pm = np.zeros(5, dtype=np.float32) return self.coefs_am, self.coefs_pm - def train(self, tx_abs, rx_abs, phase_diff): + def train(self, tx_abs, rx_abs, phase_diff, lr=None): """ :type tx_abs: np.ndarray :type rx_abs: np.ndarray :type phase_diff: np.ndarray + :type lr: float """ _check_input_get_next_coefs(tx_abs, rx_abs, phase_diff) + if not lr is None: + self.model_am.learning_rate_am = lr + self.model_pm.learning_rate_pm = lr + coefs_am_new = self.model_am.get_next_coefs(tx_abs, rx_abs, self.coefs_am) coefs_pm_new = self.model_pm.get_next_coefs(tx_abs, phase_diff, self.coefs_pm) |