aboutsummaryrefslogtreecommitdiffstats
path: root/edi
diff options
context:
space:
mode:
author Matthias P. Braendli <matthias.braendli@mpb.li> 2014-06-16 20:46:54 +0200
committer Matthias P. Braendli <matthias.braendli@mpb.li> 2014-06-16 20:46:54 +0200
commit bb7fd26a46a076a0057fbd2296a4ac4e1af73c27 (patch)
tree 1cf004c76aeaea88a1d750bb1406c4db30299092 /edi
parent f5d94805b7f42f20c0468e545741866503b9e17f (diff)
download mmbtools-aux-bb7fd26a46a076a0057fbd2296a4ac4e1af73c27.tar.gz
mmbtools-aux-bb7fd26a46a076a0057fbd2296a4ac4e1af73c27.tar.bz2
mmbtools-aux-bb7fd26a46a076a0057fbd2296a4ac4e1af73c27.zip
Add reedsolo.py
Diffstat (limited to 'edi')
-rw-r--r-- edi/reedsolo.py 254
1 files changed, 254 insertions, 0 deletions
diff --git a/edi/reedsolo.py b/edi/reedsolo.py
new file mode 100644
index 0000000..c310d22
--- /dev/null
+++ b/edi/reedsolo.py
@@ -0,0 +1,254 @@
+r"""
+Reed Solomon
+============
+
+A pure-python `Reed Solomon <http://en.wikipedia.org/wiki/Reed%E2%80%93Solomon_error_correction>`_
+encoder/decoder, based on the wonderful tutorial at
+`wikiversity <http://en.wikiversity.org/wiki/Reed%E2%80%93Solomon_codes_for_coders>`_,
+written by "Bobmath".
+
+I only consolidated the code a little and added exceptions and a simple API.
+To my understanding, the algorithm can correct up to ``nsym/2`` erroneous bytes in
+the message, where ``nsym`` is the number of bytes in the error correction code (ECC).
+The code should work on pretty much any reasonable version of python (2.4-3.2),
+but I'm only testing on 2.5-3.2.
+
+.. note::
+ I claim no authorship of the code, and take no responsibility for the correctness
+ of the algorithm. It's way too much finite-field algebra for me :)
+
+ I've released this package as I needed an ECC codec for another project I'm working on,
+ and I couldn't find anything on the web (that still works).
+
+ The algorithm itself can handle messages up to 255 bytes, including the ECC bytes. The
+ ``RSCodec`` class will split longer messages into chunks and encode/decode them separately;
+ it shouldn't make a difference from an API perspective.
+
+::
+
+ >>> rs = RSCodec(10)
+ >>> rs.encode([1,2,3,4])
+ b'\x01\x02\x03\x04,\x9d\x1c+=\xf8h\xfa\x98M'
+ >>> rs.encode(b'hello world')
+ b'hello world\xed%T\xc4\xfd\xfd\x89\xf3\xa8\xaa'
+ >>> rs.decode(b'hello world\xed%T\xc4\xfd\xfd\x89\xf3\xa8\xaa')
+ b'hello world'
+ >>> rs.decode(b'heXlo worXd\xed%T\xc4\xfdX\x89\xf3\xa8\xaa') # 3 errors
+ b'hello world'
+ >>> rs.decode(b'hXXXo worXd\xed%T\xc4\xfdX\x89\xf3\xa8\xaa') # 5 errors
+ b'hello world'
+ >>> rs.decode(b'hXXXo worXd\xed%T\xc4\xfdXX\xf3\xa8\xaa') # 6 errors - fail
+ Traceback (most recent call last):
+ ...
+ ReedSolomonError: Could not locate error
+
+ >>> rs = RSCodec(12)
+ >>> rs.encode(b'hello world')
+ b'hello world?Ay\xb2\xbc\xdc\x01q\xb9\xe3\xe2='
+ >>> rs.decode(b'hello worXXXXy\xb2XX\x01q\xb9\xe3\xe2=') # 6 errors - ok
+ b'hello world'
+"""
+
try:
    bytearray
except NameError:
    # No built-in ``bytearray`` on this interpreter: emulate the small subset
    # this module relies on with an unsigned-byte ``array``.
    from array import array

    def bytearray(obj=0, encoding="utf8"):
        """Minimal stand-in for the built-in ``bytearray``.

        :param obj: a ``str`` (encoded with *encoding*), an ``int`` (number of
            zero bytes to allocate), or any iterable of byte values.
        :param encoding: codec used when *obj* is a ``str``.
        :returns: an ``array("B")`` holding the resulting byte values.
        """
        if isinstance(obj, str):
            # Honour the caller's encoding instead of always assuming UTF-8.
            obj = [ord(ch) for ch in obj.encode(encoding)]
        elif isinstance(obj, int):
            obj = [0] * obj
        return array("B", obj)
+
+
class ReedSolomonError(Exception):
    """Raised when a Reed-Solomon message cannot be decoded or corrected."""
+
+
# Log / antilog tables for GF(2^8) with generator 2, reduced by 0x11d
# (x^8 + x^4 + x^3 + x^2 + 1).
#   gf_exp[i] -- the generator raised to the i-th power
#   gf_log[v] -- the exponent i such that gf_exp[i] == v (for v != 0)
# gf_exp is allocated with 512 slots so that a sum of two logarithms can be
# used as an index directly, without reducing it modulo 255 first.
gf_log = [0] * 256
gf_exp = [1] * 512
x = 1
for i in range(1, 255):
    # Multiply by the generator; fold back into 8 bits when the top bit
    # is about to overflow.
    x = (x << 1) ^ (0x11d if x & 0x80 else 0)
    gf_exp[i], gf_log[x] = x, i
# Repeat the cycle in the upper part of the table.
for i in range(255, 512):
    gf_exp[i] = gf_exp[i - 255]
+
def gf_mul(x, y):
    """Return the product of field elements *x* and *y*.

    Non-zero operands are multiplied by adding their logarithms and looking
    the sum up in ``gf_exp``; a zero operand yields 0.
    """
    if x and y:
        return gf_exp[gf_log[x] + gf_log[y]]
    return 0
+
def gf_div(x, y):
    """Return the quotient of field element *x* divided by *y*.

    :raises ZeroDivisionError: if *y* is 0.
    """
    if not y:
        raise ZeroDivisionError()
    if not x:
        return 0
    # Adding 255 keeps the table index non-negative when log(x) < log(y).
    log_difference = gf_log[x] + 255 - gf_log[y]
    return gf_exp[log_difference]
+
def gf_poly_scale(p, x):
    """Return a new list with every coefficient of *p* multiplied by *x*."""
    return [gf_mul(coef, x) for coef in p]
+
def gf_poly_add(p, q):
    """Return the sum of polynomials *p* and *q* as a new list.

    Coefficients are aligned at the end of the lists, so the shorter
    polynomial is padded with zeros at the front; addition is XOR.
    """
    size = max(len(p), len(q))
    total = [0] * (size - len(p)) + list(p)
    shift = size - len(q)
    for k, coef in enumerate(q):
        total[shift + k] ^= coef
    return total
+
def gf_poly_mul(p, q):
    """Return the product of polynomials *p* and *q* as a new list.

    The result has ``len(p) + len(q) - 1`` coefficients; each pairwise
    product is accumulated with XOR.
    """
    product = [0] * (len(p) + len(q) - 1)
    for j, q_coef in enumerate(q):
        for i, p_coef in enumerate(p):
            product[i + j] ^= gf_mul(p_coef, q_coef)
    return product
+
def gf_poly_eval(p, x):
    """Evaluate polynomial *p* at *x* using Horner's scheme.

    ``p[0]`` is the leading coefficient. An empty *p* raises ``IndexError``.
    """
    acc = p[0]
    for coef in p[1:]:
        acc = gf_mul(acc, x) ^ coef
    return acc
+
def rs_generator_poly(nsym):
    """Build the generator polynomial for *nsym* error-correction symbols.

    The result is the product of the factors ``[1, gf_exp[k]]`` for
    ``k = 0 .. nsym - 1``; for ``nsym == 0`` it is ``[1]``.
    """
    poly = [1]
    for k in range(nsym):
        factor = [1, gf_exp[k]]
        poly = gf_poly_mul(poly, factor)
    return poly
+
def rs_encode_msg(msg_in, nsym):
    """Return *msg_in* followed by *nsym* error-correction bytes.

    :param msg_in: sequence of byte values to protect.
    :param nsym: number of error-correction symbols to append.
    :returns: a ``bytearray`` of ``len(msg_in) + nsym`` bytes.
    :raises ValueError: if message plus ECC would exceed 255 bytes.
    """
    msg_len = len(msg_in)
    if msg_len + nsym > 255:
        raise ValueError("message too long")
    gen = rs_generator_poly(nsym)
    msg_out = bytearray(msg_len + nsym)
    msg_out[:msg_len] = msg_in
    # Divide by the generator polynomial in place; the remainder ends up in
    # the last nsym bytes.
    for i in range(msg_len):
        coef = msg_out[i]
        if coef == 0:
            continue
        for j, gen_coef in enumerate(gen):
            msg_out[i + j] ^= gf_mul(gen_coef, coef)
    # The division overwrote the message part, so put the original back.
    msg_out[:msg_len] = msg_in
    return msg_out
+
def rs_calc_syndromes(msg, nsym):
    """Return the *nsym* syndromes of *msg*.

    Syndrome ``k`` is *msg*, read as a polynomial, evaluated at ``gf_exp[k]``.
    """
    syndromes = []
    for k in range(nsym):
        syndromes.append(gf_poly_eval(msg, gf_exp[k]))
    return syndromes
+
def rs_correct_errata(msg, synd, pos):
    """Correct the symbols of *msg* at the known positions *pos*, in place.

    :param msg: mutable sequence of byte values; modified in place.
    :param synd: list of syndromes computed for *msg*.
    :param pos: indices into *msg* of the symbols to correct.
    :returns: ``None``.
    """
    msg_len = len(msg)
    count = len(pos)
    # Error locator polynomial: one factor per errata position.
    locator = [1]
    for position in pos:
        root = gf_exp[msg_len - 1 - position]
        locator = gf_poly_mul(locator, [root, 1])
    # Error evaluator polynomial: reversed leading syndromes times the
    # locator, keeping only the last `count` coefficients.
    evaluator = synd[0:count]
    evaluator.reverse()
    evaluator = gf_poly_mul(evaluator, locator)
    evaluator = evaluator[len(evaluator) - count:len(evaluator)]
    # The formal derivative of the locator eliminates the even terms.
    derivative = locator[len(locator) & 1:len(locator):2]
    # Compute and apply the correction for every position.
    for position in pos:
        x = gf_exp[position + 256 - msg_len]
        numerator = gf_poly_eval(evaluator, x)
        slope = gf_poly_eval(derivative, gf_mul(x, x))
        msg[position] ^= gf_div(numerator, gf_mul(x, slope))
+
def rs_find_errors(synd, nmess):
    """Locate the erroneous symbols of a message from its syndromes.

    :param synd: list of (Forney) syndromes.
    :param nmess: length of the message being examined.
    :returns: list of error positions, or ``None`` when the number of roots
        found does not match the degree of the error locator polynomial.
    :raises ReedSolomonError: if the locator implies more errors than
        ``len(synd) / 2``.
    """
    # Find the error locator polynomial with the Berlekamp-Massey algorithm.
    locator = [1]
    previous = [1]
    for k, discrepancy in enumerate(synd):
        previous.append(0)
        last = len(locator) - 1
        for j in range(1, len(locator)):
            discrepancy ^= gf_mul(locator[last - j], synd[k - j])
        if discrepancy == 0:
            continue
        if len(previous) > len(locator):
            scaled = gf_poly_scale(previous, discrepancy)
            previous = gf_poly_scale(locator, gf_div(1, discrepancy))
            locator = scaled
        locator = gf_poly_add(locator, gf_poly_scale(previous, discrepancy))
    error_count = len(locator) - 1
    if error_count * 2 > len(synd):
        raise ReedSolomonError("Too many errors to correct")
    # Find the zeros of the error locator polynomial.
    positions = [
        nmess - 1 - k
        for k in range(nmess)
        if gf_poly_eval(locator, gf_exp[255 - k]) == 0
    ]
    if len(positions) != error_count:
        return None  # couldn't find error locations
    return positions
+
def rs_forney_syndromes(synd, pos, nmess):
    """Return syndromes with the contribution of known erasures removed.

    :param synd: syndromes of the message; not modified.
    :param pos: indices of the erased symbols.
    :param nmess: length of the message.
    :returns: a new list that is one element shorter per erasure.
    """
    fsynd = list(synd)  # work on a copy
    for erasure in pos:
        x = gf_exp[nmess - 1 - erasure]
        for k in range(len(fsynd) - 1):
            fsynd[k] = gf_mul(fsynd[k], x) ^ fsynd[k + 1]
        fsynd.pop()
    return fsynd
+
def rs_correct_msg(msg_in, nsym):
    """Decode one Reed-Solomon block and return its message part.

    Negative values in *msg_in* are treated as erasures: symbols known to be
    missing, at known positions.

    :param msg_in: sequence of byte values, message followed by *nsym* ECC
        bytes; not modified.
    :param nsym: number of error-correction symbols at the end of *msg_in*.
    :returns: a list with the corrected message, ECC bytes stripped.
    :raises ValueError: if *msg_in* is longer than 255 symbols.
    :raises ReedSolomonError: if there are too many erasures, the errors
        cannot be located, or the corrected message still fails the check.
    """
    if len(msg_in) > 255:
        raise ValueError("message too long")
    msg_out = list(msg_in)  # copy of message
    # Find erasures and neutralise them before computing syndromes.
    erase_pos = []
    for i, symbol in enumerate(msg_out):
        if symbol < 0:
            msg_out[i] = 0
            erase_pos.append(i)
    if len(erase_pos) > nsym:
        raise ReedSolomonError("Too many erasures to correct")
    # Explicit end index: ``msg_out[:-nsym]`` would be ``msg_out[:0]`` (empty)
    # for nsym == 0.
    data_len = max(len(msg_out) - nsym, 0)
    synd = rs_calc_syndromes(msg_out, nsym)
    # ``synd`` is empty for nsym == 0, where ``max()`` would raise.
    if not synd or max(synd) == 0:
        return msg_out[:data_len]  # no errors
    fsynd = rs_forney_syndromes(synd, erase_pos, len(msg_out))
    err_pos = rs_find_errors(fsynd, len(msg_out))
    if err_pos is None:
        raise ReedSolomonError("Could not locate error")
    rs_correct_errata(msg_out, synd, erase_pos + err_pos)
    # Re-check: non-zero syndromes mean the correction did not succeed.
    synd = rs_calc_syndromes(msg_out, nsym)
    if max(synd) > 0:
        raise ReedSolomonError("Could not correct message")
    return msg_out[:data_len]
+
+
+#===================================================================================================
+# API
+#===================================================================================================
class RSCodec(object):
    """
    A Reed Solomon encoder/decoder. After initializing the object, use ``encode`` to encode a
    (byte)string to include the RS correction code, and pass such an encoded (byte)string to
    ``decode`` to extract the original message (if the number of errors allows for correct decoding).
    The ``nsym`` argument is the length of the correction code, and it determines the number of
    error bytes (if I understand this correctly, half of ``nsym`` is correctable)
    """
    def __init__(self, nsym=10):
        # Number of error-correction bytes appended to every chunk.
        self.nsym = nsym

    def encode(self, data):
        """Encode *data* and return a ``bytearray`` with the ECC bytes included.

        A ``str`` is first converted to bytes as UTF-8. The input is split
        into chunks of ``255 - nsym`` bytes and each chunk is encoded
        separately.

        :raises ValueError: if ``nsym`` is not in the range 0..254, which
            would leave no room for data in a 255-byte block.
        """
        if isinstance(data, str):
            data = bytearray(data, "utf-8")
        chunk_size = 255 - self.nsym
        # A zero step makes range() fail with an unrelated message, and a
        # negative step makes it empty, silently dropping all of the input.
        if chunk_size <= 0 or self.nsym < 0:
            raise ValueError("nsym must be between 0 and 254")
        enc = bytearray()
        for i in range(0, len(data), chunk_size):
            chunk = data[i:i + chunk_size]
            enc.extend(rs_encode_msg(chunk, self.nsym))
        return enc

    def decode(self, data):
        """Decode *data* and return the recovered message as a ``bytearray``.

        A ``str`` is first converted to bytes as UTF-8. The input is
        processed in chunks of 255 bytes, each corrected separately.
        """
        if isinstance(data, str):
            data = bytearray(data, "utf-8")
        dec = bytearray()
        for i in range(0, len(data), 255):
            chunk = data[i:i + 255]
            dec.extend(rs_correct_msg(chunk, self.nsym))
        return dec
+
+