summaryrefslogtreecommitdiffstats
path: root/dpd/src/Model_Poly.py
diff options
context:
space:
mode:
authorandreas128 <Andreas>2017-09-27 20:59:17 +0200
committerandreas128 <Andreas>2017-09-27 20:59:17 +0200
commit9954be2a2d94c6d7ebed8b36364d02e97084b9f2 (patch)
tree2173908a9b3103eaff8d42812accd167312232d8 /dpd/src/Model_Poly.py
parent071088f747f5629f60b01017bdcff5161efb7ba5 (diff)
downloaddabmod-9954be2a2d94c6d7ebed8b36364d02e97084b9f2.tar.gz
dabmod-9954be2a2d94c6d7ebed8b36364d02e97084b9f2.tar.bz2
dabmod-9954be2a2d94c6d7ebed8b36364d02e97084b9f2.zip
Change fixed learning rate and number of measurements to heuristic
Diffstat (limited to 'dpd/src/Model_Poly.py')
-rw-r--r--dpd/src/Model_Poly.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/dpd/src/Model_Poly.py b/dpd/src/Model_Poly.py
index 44e0483..16881ad 100644
--- a/dpd/src/Model_Poly.py
+++ b/dpd/src/Model_Poly.py
@@ -62,14 +62,19 @@ class Poly:
self.coefs_pm = np.zeros(5, dtype=np.float32)
return self.coefs_am, self.coefs_pm
- def train(self, tx_abs, rx_abs, phase_diff):
+ def train(self, tx_abs, rx_abs, phase_diff, lr=None):
"""
:type tx_abs: np.ndarray
:type rx_abs: np.ndarray
:type phase_diff: np.ndarray
+ :type lr: float
"""
_check_input_get_next_coefs(tx_abs, rx_abs, phase_diff)
+ if not lr is None:
+ self.model_am.learning_rate_am = lr
+ self.model_pm.learning_rate_pm = lr
+
coefs_am_new = self.model_am.get_next_coefs(tx_abs, rx_abs, self.coefs_am)
coefs_pm_new = self.model_pm.get_next_coefs(tx_abs, phase_diff, self.coefs_pm)