diff options
| -rw-r--r-- | edi/reedsolo.py | 254 | 
1 files changed, 254 insertions, 0 deletions
| diff --git a/edi/reedsolo.py b/edi/reedsolo.py new file mode 100644 index 0000000..c310d22 --- /dev/null +++ b/edi/reedsolo.py @@ -0,0 +1,254 @@ +r""" +Reed Solomon +============ + +A pure-python `Reed Solomon <http://en.wikipedia.org/wiki/Reed%E2%80%93Solomon_error_correction>`_ +encoder/decoder, based on the wonderful tutorial at  +`wikiversity <http://en.wikiversity.org/wiki/Reed%E2%80%93Solomon_codes_for_coders>`_, +written by "Bobmath". + +I only consolidated the code a little and added exceptions and a simple API.  +To my understanding, the algorithm can correct up to ``nsym/2`` of the errors in  +the message, where ``nsym`` is the number of bytes in the error correction code (ECC). +The code should work on pretty much any reasonable version of python (2.4-3.2),  +but I'm only testing on 2.5-3.2. + +.. note:: +   I claim no authorship of the code, and take no responsibility for the correctness  +   of the algorithm. It's way too much finite-field algebra for me :) +    +   I've released this package as I needed an ECC codec for another project I'm working on,  +   and I couldn't find anything on the web (that still works). +    +   The algorithm itself can handle messages up to 255 bytes, including the ECC bytes. The +   ``RSCodec`` class will split longer messages into chunks and encode/decode them separately; +   it shouldn't make a difference from an API perspective. + +:: + +    >>> rs = RSCodec(10) +    >>> rs.encode([1,2,3,4]) +    b'\x01\x02\x03\x04,\x9d\x1c+=\xf8h\xfa\x98M' +    >>> rs.encode(b'hello world') +    b'hello world\xed%T\xc4\xfd\xfd\x89\xf3\xa8\xaa' +    >>> rs.decode(b'hello world\xed%T\xc4\xfd\xfd\x89\xf3\xa8\xaa') +    b'hello world' +    >>> rs.decode(b'heXlo worXd\xed%T\xc4\xfdX\x89\xf3\xa8\xaa')     # 3 errors +    b'hello world' +    >>> rs.decode(b'hXXXo worXd\xed%T\xc4\xfdX\x89\xf3\xa8\xaa')     # 5 errors +    b'hello world' +    >>> rs.decode(b'hXXXo worXd\xed%T\xc4\xfdXX\xf3\xa8\xaa')        # 6 errors - fail +    Traceback (most recent call last): +      ... +    ReedSolomonError: Could not locate error + +    >>> rs = RSCodec(12) +    >>> rs.encode(b'hello world') +    b'hello world?Ay\xb2\xbc\xdc\x01q\xb9\xe3\xe2=' +    >>> rs.decode(b'hello worXXXXy\xb2XX\x01q\xb9\xe3\xe2=')         # 6 errors - ok +    b'hello world' +""" + +try: +    bytearray +except NameError: +    from array import array +    def bytearray(obj = 0, encoding = "utf8"): +        if isinstance(obj, str): +            obj = [ord(ch) for ch in obj.encode("utf8")] +        elif isinstance(obj, int): +            obj = [0] * obj +        return array("B", obj) + + +class ReedSolomonError(Exception): +    pass + + +gf_exp = [1] * 512 +gf_log = [0] * 256 +x = 1 +for i in range(1, 255): +    x <<= 1 +    if x & 0x100: +        x ^= 0x11d +    gf_exp[i] = x +    gf_log[x] = i +for i in range(255, 512): +    gf_exp[i] = gf_exp[i - 255] + +def gf_mul(x, y): +    if x == 0 or y == 0: +        return 0 +    return gf_exp[gf_log[x] + gf_log[y]] + +def gf_div(x, y): +    if y == 0: +        raise ZeroDivisionError() +    if x == 0: +        return 0 +    return gf_exp[gf_log[x] + 255 - gf_log[y]] + +def gf_poly_scale(p, x): +    return [gf_mul(p[i], x) for i in range(0, len(p))] + +def gf_poly_add(p, q): +    r = [0] * max(len(p), len(q)) +    for i in range(0, len(p)): +        r[i + len(r) - len(p)] = p[i] +    for i in range(0, len(q)): +        r[i + len(r) - len(q)] ^= q[i] +    return r + +def gf_poly_mul(p, q): +    r = [0] * (len(p) + len(q) - 1) +    for j in range(0, len(q)): +        for i in range(0, len(p)): +            r[i + j] ^= gf_mul(p[i], q[j]) +    return r + +def gf_poly_eval(p, x): +    y = p[0] +    for i in range(1, len(p)): +        y = gf_mul(y, x) ^ p[i] +    return y + +def rs_generator_poly(nsym): +    g = [1] +    for i in range(0, nsym): +        g = gf_poly_mul(g, [1, gf_exp[i]]) +    return g + +def rs_encode_msg(msg_in, nsym): +    if len(msg_in) + nsym > 255: +        raise ValueError("message too long") +    gen = rs_generator_poly(nsym) +    msg_out = bytearray(len(msg_in) + nsym) +    msg_out[:len(msg_in)] = msg_in +    for i in range(0, len(msg_in)): +        coef = msg_out[i] +        if coef != 0: +            for j in range(0, len(gen)): +                msg_out[i + j] ^= gf_mul(gen[j], coef) +    msg_out[:len(msg_in)] = msg_in +    return msg_out + +def rs_calc_syndromes(msg, nsym): +    return [gf_poly_eval(msg, gf_exp[i]) for i in range(nsym)] + +def rs_correct_errata(msg, synd, pos): +    # calculate error locator polynomial +    q = [1] +    for i in range(0, len(pos)): +        x = gf_exp[len(msg) - 1 - pos[i]] +        q = gf_poly_mul(q, [x, 1]) +    # calculate error evaluator polynomial +    p = synd[0:len(pos)] +    p.reverse() +    p = gf_poly_mul(p, q) +    p = p[len(p) - len(pos):len(p)] +    # formal derivative of error locator eliminates even terms +    q = q[len(q) & 1:len(q):2] +    # compute corrections +    for i in range(0, len(pos)): +        x = gf_exp[pos[i] + 256 - len(msg)] +        y = gf_poly_eval(p, x) +        z = gf_poly_eval(q, gf_mul(x, x)) +        msg[pos[i]] ^= gf_div(y, gf_mul(x, z)) + +def rs_find_errors(synd, nmess): +    # find error locator polynomial with Berlekamp-Massey algorithm +    err_poly = [1] +    old_poly = [1] +    for i in range(0, len(synd)): +        old_poly.append(0) +        delta = synd[i] +        for j in range(1, len(err_poly)): +            delta ^= gf_mul(err_poly[len(err_poly) - 1 - j], synd[i - j]) +        if delta != 0: +            if len(old_poly) > len(err_poly): +                new_poly = gf_poly_scale(old_poly, delta) +                old_poly = gf_poly_scale(err_poly, gf_div(1, delta)) +                err_poly = new_poly +            err_poly = gf_poly_add(err_poly, gf_poly_scale(old_poly, delta)) +    errs = len(err_poly) - 1 +    if errs * 2 > len(synd): +        raise ReedSolomonError("Too many errors to correct") +    # find zeros of error polynomial +    err_pos = [] +    for i in range(0, nmess): +        if gf_poly_eval(err_poly, gf_exp[255 - i]) == 0: +            err_pos.append(nmess - 1 - i) +    if len(err_pos) != errs: +        return None    # couldn't find error locations +    return err_pos + +def rs_forney_syndromes(synd, pos, nmess): +    fsynd = list(synd)      # make a copy +    for i in range(0, len(pos)): +        x = gf_exp[nmess - 1 - pos[i]] +        for i in range(0, len(fsynd) - 1): +            fsynd[i] = gf_mul(fsynd[i], x) ^ fsynd[i + 1] +        fsynd.pop() +    return fsynd + +def rs_correct_msg(msg_in, nsym): +    if len(msg_in) > 255: +        raise ValueError("message too long") +    msg_out = list(msg_in)     # copy of message +    # find erasures +    erase_pos = [] +    for i in range(0, len(msg_out)): +        if msg_out[i] < 0: +            msg_out[i] = 0 +            erase_pos.append(i) +    if len(erase_pos) > nsym: +        raise ReedSolomonError("Too many erasures to correct") +    synd = rs_calc_syndromes(msg_out, nsym) +    if max(synd) == 0: +        return msg_out[:-nsym]  # no errors +    fsynd = rs_forney_syndromes(synd, erase_pos, len(msg_out)) +    err_pos = rs_find_errors(fsynd, len(msg_out)) +    if err_pos is None: +        raise ReedSolomonError("Could not locate error") +    rs_correct_errata(msg_out, synd, erase_pos + err_pos) +    synd = rs_calc_syndromes(msg_out, nsym) +    if max(synd) > 0: +        raise ReedSolomonError("Could not correct message") +    return msg_out[:-nsym] + + +#=================================================================================================== +# API +#=================================================================================================== +class RSCodec(object): +    """ +    A Reed Solomon encoder/decoder. After initializing the object, use ``encode`` to encode a  +    (byte)string to include the RS correction code, and pass such an encoded (byte)string to +    ``decode`` to extract the original message (if the number of errors allows for correct decoding). +    The ``nsym`` argument is the length of the correction code, and it determines the number of  +    error bytes (if I understand this correctly, half of ``nsym`` is correctable) +    """ +    def __init__(self, nsym=10): +        self.nsym = nsym + +    def encode(self, data): +        if isinstance(data, str): +            data = bytearray(data, "utf-8") +        chunk_size = 255 - self.nsym +        enc = bytearray() +        for i in range(0, len(data), chunk_size): +            chunk = data[i:i+chunk_size] +            enc.extend(rs_encode_msg(chunk, self.nsym)) +        return enc +     +    def decode(self, data): +        if isinstance(data, str): +            data = bytearray(data, "utf-8") +        dec = bytearray() +        for i in range(0, len(data), 255): +            chunk = data[i:i+255] +            dec.extend(rs_correct_msg(chunk, self.nsym)) +        return dec + + | 
