diff options
| author | andreas128 <Andreas> | 2017-09-28 18:59:35 +0200 | 
|---|---|---|
| committer | andreas128 <Andreas> | 2017-09-28 18:59:35 +0200 | 
| commit | 253be52c23528544d54a59b649a60193fffb2848 (patch) | |
| tree | 67bd74ca1f35ec0dc7dee34207b5aa652443e485 /dpd/src/Model_AM.py | |
| parent | 74765b949c8d597ec906fd49733a035028095d54 (diff) | |
| download | dabmod-253be52c23528544d54a59b649a60193fffb2848.tar.gz dabmod-253be52c23528544d54a59b649a60193fffb2848.tar.bz2 dabmod-253be52c23528544d54a59b649a60193fffb2848.zip | |
Cleanup
Diffstat (limited to 'dpd/src/Model_AM.py')
| -rw-r--r-- | dpd/src/Model_AM.py | 48 | 
1 files changed, 26 insertions, 22 deletions
| diff --git a/dpd/src/Model_AM.py b/dpd/src/Model_AM.py index 85f6495..cdc3de1 100644 --- a/dpd/src/Model_AM.py +++ b/dpd/src/Model_AM.py @@ -15,14 +15,30 @@ import numpy as np  import matplotlib.pyplot as plt +def is_npfloat32(array): +    assert isinstance(array, np.ndarray), type(array) +    assert array.dtype == np.float32, array.dtype +    assert array.flags.contiguous +    assert not any(np.isnan(array)) + +  def check_input_get_next_coefs(tx_dpd, rx_received): -    is_float32 = lambda x: (isinstance(x, np.ndarray) and -                            x.dtype == np.float32 and -                            x.flags.contiguous) -    assert is_float32(tx_dpd), \ -        "tx_dpd is not float32 but {}".format(tx_dpd[0].dtype) -    assert is_float32(rx_received), \ -        "rx_received is not float32 but {}".format(tx_dpd[0].dtype) +    is_npfloat32(tx_dpd) +    is_npfloat32(rx_received) + + +def poly(sig): +    return np.array([sig ** i for i in range(1, 6)]).T + + +def fit_poly(tx_abs, rx_abs): +    return np.linalg.lstsq(poly(rx_abs), tx_abs)[0] + + +def calc_line(coefs, min_amp, max_amp): +    rx_range = np.linspace(min_amp, max_amp) +    tx_est = np.sum(poly(rx_range) * coefs, axis=1) +    return tx_est, rx_range  class Model_AM: @@ -40,8 +56,8 @@ class Model_AM:      def _plot(self, tx_dpd, rx_received, coefs_am, coefs_am_new):          if logging.getLogger().getEffectiveLevel() == logging.DEBUG and self.plot: -            tx_range, rx_est = self.calc_line(coefs_am, 0, 0.6) -            tx_range_new, rx_est_new = self.calc_line(coefs_am_new, 0, 0.6) +            tx_range, rx_est = calc_line(coefs_am, 0, 0.6) +            tx_range_new, rx_est_new = calc_line(coefs_am_new, 0, 0.6)              dt = datetime.datetime.now().isoformat()              fig_path = logging_path + "/" + dt + "_Model_AM.svg" @@ -66,7 +82,6 @@ class Model_AM:              ax.set_title("Model_AM")              ax.set_xlabel("TX Amplitude")              ax.set_ylabel("RX Amplitude") -            xlim = ax.get_xlim()              ax.set_xlim(-0.5, 1.5)              ax.legend(loc=4) @@ -74,21 +89,10 @@ class Model_AM:              fig.savefig(fig_path)              plt.close(fig) -    def poly(self, sig): -        return np.array([sig ** i for i in range(1, 6)]).T - -    def fit_poly(self, tx_abs, rx_abs): -        return np.linalg.lstsq(self.poly(rx_abs), tx_abs)[0] - -    def calc_line(self, coefs, min_amp, max_amp): -        rx_range = np.linspace(min_amp, max_amp) -        tx_est = np.sum(self.poly(rx_range) * coefs, axis=1) -        return tx_est, rx_range -      def get_next_coefs(self, tx_dpd, rx_received, coefs_am):          check_input_get_next_coefs(tx_dpd, rx_received) -        coefs_am_new = self.fit_poly(tx_dpd, rx_received) +        coefs_am_new = fit_poly(tx_dpd, rx_received)          coefs_am_new = coefs_am + \                         self.learning_rate_am * (coefs_am_new - coefs_am) | 
