aboutsummaryrefslogtreecommitdiffstats
path: root/src/fft_lag.py
diff options
context:
space:
mode:
authorandreas128 <Andreas>2017-02-26 00:22:17 +0000
committerandreas128 <Andreas>2017-02-26 00:22:17 +0000
commit2476ba7e03035caa1db7ad8a0dc6718501680c44 (patch)
tree98dfad3b0976186cbaebcce06bbd04131da0f2d7 /src/fft_lag.py
parent6e0b2512e45b7a6ca03187814742cb0fe08964cb (diff)
downloadODR-StaticPrecorrection-2476ba7e03035caa1db7ad8a0dc6718501680c44.tar.gz
ODR-StaticPrecorrection-2476ba7e03035caa1db7ad8a0dc6718501680c44.tar.bz2
ODR-StaticPrecorrection-2476ba7e03035caa1db7ad8a0dc6718501680c44.zip
Add sub sample accurate fft lag finder
Diffstat (limited to 'src/fft_lag.py')
-rw-r--r--src/fft_lag.py165
1 files changed, 165 insertions, 0 deletions
diff --git a/src/fft_lag.py b/src/fft_lag.py
new file mode 100644
index 0000000..2d8d733
--- /dev/null
+++ b/src/fft_lag.py
@@ -0,0 +1,165 @@
+import numpy as np
+from numpy.fft import rfftn, irfftn
+from scipy.fftpack import (fft, ifft, ifftshift, fft2, ifft2, fftn,
+ ifftn, fftfreq)
+from scipy._lib._version import NumpyVersion
+from scipy import signal
+
+########################################
+##### Test functions
+########################################
+def create_triangular_chirp(f = 0.2, periods = 100):
+ #f ... per sample
+ t_max = periods / float(f)
+ t = np.arange(t_max)
+ sig = signal.chirp(t, 0, t_max, f)
+ triangle = np.linspace(0, t_max, num = t_max) / t_max
+ sig = np.multiply(sig, triangle)
+ return sig
+
+def add_delay_and_padding(sig, delay=0, padding=0):
+ n_sig = sig.shape[0]
+ n_total = n_sig + delay + padding
+ ret = np.zeros((n_total), dtype=sig.dtype)
+ ret[delay:delay + n_sig] = sig
+ return ret
+
+def create_test_signal(f=0.05, periods=100, delay=0, padding=0):
+ return add_delay_and_padding(create_triangular_chirp(f=f, periods=periods),
+ delay=delay, padding=padding)
+
+def visualize_test_signal(f=0.05, periods=100, delay=10, padding=100):
+ plt.plot(create_test_signal(f=f, periods=periods, delay=delay, padding=padding))
+
+def down_sample(sig, n_every):
+ return sig[::n_every]
+
+
+
+_rfft_mt_safe = (NumpyVersion(np.__version__) >= '1.9.0.dev-e24486e')
+
+def _next_regular(target):
+ """
+ Find the next regular number greater than or equal to target.
+ Regular numbers are composites of the prime factors 2, 3, and 5.
+ Also known as 5-smooth numbers or Hamming numbers, these are the optimal
+ size for inputs to FFTPACK.
+ Target must be a positive integer.
+ """
+ if target <= 6:
+ return target
+
+ # Quickly check if it's already a power of 2
+ if not (target & (target-1)):
+ return target
+
+ match = float('inf') # Anything found will be smaller
+ p5 = 1
+ while p5 < target:
+ p35 = p5
+ while p35 < target:
+ # Ceiling integer division, avoiding conversion to float
+ # (quotient = ceil(target / p35))
+ quotient = -(-target // p35)
+
+ # Quickly find next power of 2 >= quotient
+ try:
+ p2 = 2**((quotient - 1).bit_length())
+ except AttributeError:
+ # Fallback for Python <2.7
+ p2 = 2**(len(bin(quotient - 1)) - 2)
+
+ N = p2 * p35
+ if N == target:
+ return N
+ elif N < match:
+ match = N
+ p35 *= 3
+ if p35 == target:
+ return p35
+ if p35 < match:
+ match = p35
+ p5 *= 5
+ if p5 == target:
+ return p5
+ if p5 < match:
+ match = p5
+ return match
+
+def fft_lag(s1, s2, n_up = 1, debug=False):
+ if debug:
+ import matplotlib.pyplot as plt
+ s1 = np.flipud(s1) #mult becomes convolution -> filp to get correlation
+
+ sh1 = np.array(s1.shape)
+ sh2 = np.array(s2.shape)
+ complex_result = (np.issubdtype(s1.dtype, np.complex) or
+ np.issubdtype(s2.dtype, np.complex))
+ shape = (sh1 + sh2 - 1)
+
+ fshape = [_next_regular(int(d)) for d in shape]
+ fslice = tuple([slice(0, int(sz)) for sz in shape])
+
+ def upsample_fft(s_fft, n_up):
+ n = s_fft.shape[0]
+ dtype = s_fft.dtype
+ ret = None
+ if n % 2 == 0:
+ ret = np.zeros((n*n_up), dtype=dtype)
+ ret[:n/2] = s_fft[0:n/2]
+ ret[-n/2:] = s_fft[-n/2:]
+ else:
+ ret = np.zeros((n*n_up), dtype=dtype)
+ ret[:n/2] = s_fft[0:n/2]
+ ret[n/2+1] = s_fft[n/2+1] / 2.
+ ret[-(n/2+1)] = s_fft[n/2+1] / 2.
+ ret[-n/2:] = s_fft[-n/2:]
+ ret = ret * n_up
+ return ret
+
+ s1_fft = np.fft.fftn(s1, fshape)
+ s2_fft = np.fft.fftn(s2, fshape)
+ s1_s2_fft = upsample_fft(s1_fft * s2_fft, n_up)
+ s1_s2 = np.fft.ifftn(s1_s2_fft)
+
+ delay = np.argmax(s1_s2) / float(n_up) - sh1[0] + 1
+ #peak of correlation - size of original signal (robust against padding at the end)
+
+ if debug:
+ plt.subplot(411); plt.plot(s1_fft)
+ plt.subplot(412); plt.plot(s2_fft)
+ plt.subplot(413); plt.plot(s1_s2_fft)
+ plt.subplot(414); plt.plot(s1_s2)
+ plt.show()
+ print(s1_s2.shape, s1_fft.shape, s2_fft.shape, fshape, fslice)
+
+
+ #return np.argmax(s1_s2)/float(n_up)
+ #return np.argmax(s1_s2)/float(n_up) - s2.shape[0] + 1
+ return delay
+
+def fft_lag_random_test(n_tests=1000):
+ def r():
+ return np.random.randint(0, 1000)
+ def rand(n):
+ return [np.random.randint(0, 1000) for i in range(n)]
+
+ for i in range(n_tests):
+ debug = (i == 0)
+ d1, d2, p1, p2 = rand(4)
+ n_down = 5
+ n_up = 32
+ sig1 = down_sample(create_test_signal(delay=d1, padding=p1), n_down)
+ sig2 = down_sample(create_test_signal(delay=d2, padding=p2), n_down)
+
+ d1, d2, p1, p2 = [x/float(n_down) for x in [d1, d2, p1, p2]]
+ delay = d2 - d1
+ delay_meas = fft_lag(sig1, sig2, n_up=n_up, debug = debug)
+ tol = 1./n_up
+ error = abs(delay - delay_meas)
+ assert(error < tol)
+ print("%d tests within tolerance" % n_tests)
+
+
+if __name__ == "__main__":
+ fft_lag_random_test()