diff options
| author | andreas128 <Andreas> | 2017-09-27 20:59:17 +0200 | 
|---|---|---|
| committer | andreas128 <Andreas> | 2017-09-27 20:59:17 +0200 | 
| commit | 9954be2a2d94c6d7ebed8b36364d02e97084b9f2 (patch) | |
| tree | 2173908a9b3103eaff8d42812accd167312232d8 /dpd/src/Model_AM.py | |
| parent | 071088f747f5629f60b01017bdcff5161efb7ba5 (diff) | |
| download | dabmod-9954be2a2d94c6d7ebed8b36364d02e97084b9f2.tar.gz dabmod-9954be2a2d94c6d7ebed8b36364d02e97084b9f2.tar.bz2 dabmod-9954be2a2d94c6d7ebed8b36364d02e97084b9f2.zip | |
Change fixed learning rate and number of measurements to heuristic
Diffstat (limited to 'dpd/src/Model_AM.py')
| -rw-r--r-- | dpd/src/Model_AM.py | 15 | 
1 files changed, 10 insertions, 5 deletions
| diff --git a/dpd/src/Model_AM.py b/dpd/src/Model_AM.py index 4b88e08..c6f7903 100644 --- a/dpd/src/Model_AM.py +++ b/dpd/src/Model_AM.py @@ -17,12 +17,12 @@ import matplotlib.pyplot as plt  def check_input_get_next_coefs(tx_dpd, rx_received):      is_float32 = lambda x: (isinstance(x, np.ndarray) and -                        x.dtype == np.float32 and -                        x.flags.contiguous) +                            x.dtype == np.float32 and +                            x.flags.contiguous)      assert is_float32(tx_dpd), \ -           "tx_dpd is not float32 but {}".format(tx_dpd[0].dtype) +        "tx_dpd is not float32 but {}".format(tx_dpd[0].dtype)      assert is_float32(rx_received), \ -            "rx_received is not float32 but {}".format(tx_dpd[0].dtype) +        "rx_received is not float32 but {}".format(tx_dpd[0].dtype)  class Model_AM: @@ -31,7 +31,7 @@ class Model_AM:      def __init__(self,                   c, -                 learning_rate_am=0.1, +                 learning_rate_am=1,                   plot=False):          self.c = c @@ -66,6 +66,8 @@ class Model_AM:              ax.set_title("Model_AM")              ax.set_xlabel("TX Amplitude")              ax.set_ylabel("RX Amplitude") +            xlim = ax.get_xlim() +            ax.set_xlim(max(xlim[0], -1), min(xlim[1], 2))              ax.legend(loc=4)              fig.tight_layout() @@ -87,6 +89,9 @@ class Model_AM:          check_input_get_next_coefs(tx_dpd, rx_received)          coefs_am_new = self.fit_poly(tx_dpd, rx_received) +        coefs_am_new = coefs_am + \ +                       self.learning_rate_am * (coefs_am_new - coefs_am) +          self._plot(tx_dpd, rx_received, coefs_am, coefs_am_new)          return coefs_am_new | 
