#!/usr/bin/env python3
#
# Copyright 2017 Ettus Research, National Instruments Company
#
# SPDX-License-Identifier: GPL-3.0-or-later
#
"""
N3XX Built-In Self Test (BIST)

Will work on all derivatives of the N3xx series.
"""

from __future__ import print_function
import os
import sys
import subprocess
import re
import socket
import select
import time
import json
from datetime import datetime
import argparse
from six import iteritems

# Default timeouts for the GPSDO test. All values are in seconds. Both can be
# overridden at run time through the 'gps_warmup_timeout' and
# 'gps_lockok_timeout' test options.
GPS_WARMUP_TIMEOUT = 70 # Data sheet says "about a minute"
GPS_LOCKOK_TIMEOUT = 2 # Data sheet says about 15 minutes. Because our test
                       # does not necessarily require GPS lock to pass, we
                       # reduce this value in order for the BIST to pass faster
                       # by default.

##############################################################################
# Aurora/SFP BIST code
##############################################################################
def get_sfp_bist_defaults():
    """
    Return the canned result dictionary used by the SFP/Aurora BISTs when
    running in dry-run mode.

    A new dictionary is created on every call, so callers are free to modify
    the returned value.
    """
    defaults = dict(
        elapsed_time=1.0,
        max_roundtrip_latency=0.8e-6,
        throughput=1000e6,
        max_ber=8.5e-11,
        errors=0,
        bits=12012486656,
    )
    return defaults

def run_aurora_bist(master, slave=None):
    """
    Spawn a BER test on the Aurora core(s).

    Arguments:
    master -- UIO label of the Aurora register block that drives the test
    slave -- UIO label of a second Aurora register block, or None if the
             master is looped back onto itself

    Returns whatever AuroraControl.run_ber_loopback_bist() returns.

    Any exception raised while setting up or running the test is reported on
    stderr and terminates the process with exit code 1 (SystemExit).
    NOTE(review): exiting from a helper means no JSON results get posted for
    this run; this is the original behaviour and has been kept.
    """
    from usrp_mpm import aurora_control
    from usrp_mpm.sys_utils.uio import UIO
    # Every UIO object that was successfully constructed is tracked here, so
    # the cleanup code never touches a name that was not bound yet (e.g. when
    # the UIO constructor itself fails).
    uio_objects = []
    try:
        master_au_uio = UIO(label=master, read_only=False)
        uio_objects.append(master_au_uio)
        master_au_uio.open()
        master_au_ctrl = aurora_control.AuroraControl(master_au_uio)
        slave_au_ctrl = None
        if slave is not None:
            slave_au_uio = UIO(label=slave, read_only=False)
            uio_objects.append(slave_au_uio)
            slave_au_uio.open()
            slave_au_ctrl = aurora_control.AuroraControl(slave_au_uio)
        return master_au_ctrl.run_ber_loopback_bist(
            duration=10,
            requested_rate=1300 * 8e6,
            slave=slave_au_ctrl,
        )
    except Exception as ex:
        # stdout is reserved for the JSON results, so diagnostics go to stderr
        sys.stderr.write("Unexpected exception: {}\n".format(str(ex)))
        # sys.exit() rather than the site-provided exit(), which is meant for
        # interactive use and is not guaranteed to exist.
        sys.exit(1)
    finally:
        # Same order as before: master first, then slave
        for uio_object in uio_objects:
            uio_object.close()

def aurora_results_to_status(bist_results):
    """
    Translate a result dictionary coming from the AuroraControl BIST into the
    (status, data) tuple used by this BIST tool.

    The status is True if the master reported zero errors. The data
    dictionary is a renamed subset of the input.

    NOTE(review): 'max_roundtrip_latency' is copied verbatim from
    'mst_latency_us'. The source key name suggests microseconds -- confirm
    the unit against AuroraControl before relying on it.
    """
    passed = bist_results['mst_errors'] == 0
    # (key in our result, key in the AuroraControl result)
    key_translation = (
        ('elapsed_time', 'time_elapsed'),
        ('max_roundtrip_latency', 'mst_latency_us'),
        ('throughput', 'approx_throughput'),
        ('max_ber', 'max_ber'),
        ('errors', 'mst_errors'),
        ('bits', 'mst_samps'),
    )
    status_data = {}
    for our_key, aurora_key in key_translation:
        status_data[our_key] = bist_results[aurora_key]
    return passed, status_data

##############################################################################
# Helpers
##############################################################################
def post_results(results):
    """
    Publish a results dictionary.

    The dictionary is serialized as pretty-printed JSON (keys sorted, four
    spaces of indentation) and written to stdout.
    """
    json_formatting = {
        'sort_keys': True,
        'indent': 4,
        'separators': (',', ': '),
    }
    serialized = json.dumps(results, **json_formatting)
    print(serialized)

def filter_results_for_lv(results):
    """
    The LabView JSON parser does not support a variety of things, such as
    nested dicts, and some downstream LV applications freak out if certain keys
    are not what they expect.
    This is a long hard-coded list of how results should look like for those
    cases. Note: This list needs manual supervision and attention for the case
    where either subsystems get renamed, or other architectural changes should
    occur.

    Arguments:
    results -- Dictionary mapping test names to their result dictionaries

    Returns a new dictionary. Results of tests that are listed in the
    compatibility table are reduced/padded to exactly the keys of their table
    entry, plus 'error_msg' and 'status'. Results of all other tests are
    passed through untouched.
    """
    lv_compat_format = {
        'ddr3': {
            'throughput': -1,
        },
        'gpsdo': {
            "class": "",
            "time": "",
            "ept": -1,
            "lat": -1,
            "lon": -1,
            "alt": -1,
            "epx": -1,
            "epy": -1,
            "epv": -1,
            "track": -1,
            "speed": -1,
            "climb": -1,
            "eps": -1,
            "mode": -1,
        },
        'tpm': {
            'tpm0_caps': "",
        },
        'sfp0_loopback': {
            'elapsed_time': -1,
            'max_roundtrip_latency': -1,
            'throughput': -1,
            'max_ber': -1,
            'errors': -1,
            'bits': -1,
        },
        'sfp1_loopback': {
            'elapsed_time': -1,
            'max_roundtrip_latency': -1,
            'throughput': -1,
            'max_ber': -1,
            'errors': -1,
            'bits': -1,
        },
        'gpio': {
            'write_patterns': [],
            'read_patterns': [],
        },
        'temp': {
            'fpga-thermal-zone': -1,
        },
        'fan': {
            'cooling_device0': -1,
            'cooling_device1': -1,
        },
    }
    # OK now go and brush up the results:
    def fixup_dict(result_dict, ref_dict):
        """
        Touches up result_dict according to ref_dict by the following rules:
        - If a key is in result_dict that is not in ref_dict, delete that
        - If a key is in ref_dict that is not in result_dict, use the value
          from ref_dict
        'error_msg' and 'status' are always part of the output; they default
        to "" and False, respectively.
        """
        # Work on a copy so the reference table itself is never modified
        template = dict(ref_dict, error_msg="", status=False)
        # Walking the template does both jobs at once: keys unknown to the
        # template are dropped, and missing keys receive the template default.
        return {
            key: result_dict.get(key, default)
            for key, default in template.items()
        }
    return {
        testname: fixup_dict(testresults, lv_compat_format[testname])
                  if testname in lv_compat_format else testresults
        for testname, testresults in results.items()
    }

def sock_read_line(my_sock, timeout=60, interval=0.1):
    """
    Read from a socket until newline. If there was no newline until the timeout
    occurs, raise an error. Otherwise, return the line.

    Arguments:
    my_sock -- A connected socket object
    timeout -- Maximum time (in seconds) to wait for a complete line
    interval -- Maximum time (in seconds) to wait for data before checking
                the timeout again

    Returns the line as a string, decoded as ASCII, without the newline.

    Raises RuntimeError if the timeout expires, or if the peer closes the
    connection before a newline was received.
    """
    line = bytearray()
    end_time = time.time() + timeout
    while time.time() < end_time:
        # Let select() do the waiting: we wake up as soon as data is
        # available, but no later than after 'interval' seconds.
        socket_ready = select.select([my_sock], [], [], interval)[0]
        if not socket_ready:
            continue
        next_char = my_sock.recv(1)
        if not next_char:
            # A readable socket that returns no data has been closed by the
            # peer. Without this check we would spin here until the timeout.
            raise RuntimeError(
                "sock_read_line(): connection closed before newline!")
        if next_char == b'\n':
            return line.decode('ascii')
        line += next_char
    raise RuntimeError("sock_read_line() exceeded read timeout!")

def poll_with_timeout(state_check, timeout_ms, interval_ms):
    """
    Repeatedly call state_check() until it reports success or time runs out.

    Arguments:
    state_check -- Callable without arguments; any truthy return value counts
                   as success
    timeout_ms -- Total time budget in milliseconds
    interval_ms -- Pause between two calls of state_check() in milliseconds

    Returns True if state_check() succeeded before the deadline, False
    otherwise. The deadline is checked before every call, so state_check() is
    not called at all if the time budget is already used up.
    """
    deadline = time.time() + (float(timeout_ms) / 1000)
    pause_s = float(interval_ms) / 1000
    while True:
        if not time.time() < deadline:
            return False
        if state_check():
            return True
        time.sleep(pause_s)

def expand_options(option_list):
    """
    Turn a list ['foo=bar', 'spam=eggs'] into a dictionary {'foo': 'bar',
    'spam': 'eggs'}.

    Only the first '=' separates key from value, so values may contain '='
    themselves ('args=addr=1.2.3.4' becomes {'args': 'addr=1.2.3.4'}). If a
    key is given more than once, the last value wins.

    Raises ValueError if an entry does not contain a '='.
    """
    options = {}
    for option in option_list:
        key, separator, value = option.partition('=')
        if not separator:
            raise ValueError(
                "Invalid option '{}', expected format is key=value".format(
                    option))
        options[key] = value
    return options

##############################################################################
# Bist class
##############################################################################
class N3XXBIST(object):
    """
    BIST Tool for the USRP N3xx series

    Every test is a method called bist_<testname>, which returns a tuple
    (status, data): status is the pass/fail boolean, data is a dictionary
    with test-specific details. run() executes all requested tests and posts
    the combined results as JSON.
    """
    # This defines special tests that are really collections of other tests.
    collections = {
        'standard': ["ddr3", "gpsdo", "rtc", "temp", "clock_int", "tpm"],
        'extended': "*",
    }

    @staticmethod
    def make_arg_parser():
        """
        Return arg parser
        """
        parser = argparse.ArgumentParser(
            description="N3xx BIST Tool",
        )
        parser.add_argument(
            '-n', '--dry-run', action='store_true',
            help="Fake out the tests. All tests will return a valid"
                 " response, but will not actually interact with hardware.",
        )
        parser.add_argument(
            '-v', '--verbose', action='store_true',
            help="Crank up verbosity level",
        )
        parser.add_argument(
            '--debug', action='store_true',
            help="For debugging this tool.",
        )
        parser.add_argument(
            '--option', '-o', action='append', default=[],
            help="Option for individual test.",
        )
        parser.add_argument(
            '--lv-compat', action='store_true',
            help="Provides compatibility with the LV JSON parser. Don't run "
                 "this mode unless you know what you're doing. The JSON "
                 "output does not necessarily reflect the actual system "
                 "status when using this mode.",
        )
        parser.add_argument(
            'tests',
            help="List the tests that should be run",
            nargs='+', # There has to be at least one
        )
        return parser

    def __init__(self):
        """
        Parse the command line and figure out which tests need to be run.
        """
        self.args = N3XXBIST.make_arg_parser().parse_args()
        self.args.option = expand_options(self.args.option)
        try:
            from usrp_mpm.periph_manager.n3xx import n3xx
            default_rev = n3xx.mboard_max_rev
        except ImportError:
            # This means we're in dry run mode or something like that, so just
            # pick something
            default_rev = 3
        self.mb_rev = int(self.args.option.get('mb_rev', default_rev))
        self.tests_to_run = set()
        for test in self.args.tests:
            if test in self.collections:
                self.tests_to_run.update(self.expand_collection(test))
            else:
                self.tests_to_run.add(test)
        # Make sure the attribute always exists, even without MPM
        self.log = None
        try:
            # Keep this import here so we can do dry-runs without any MPM code
            from usrp_mpm import get_main_logger
            if not self.args.verbose:
                from usrp_mpm.mpmlog import WARNING
                get_main_logger().setLevel(WARNING)
            self.log = get_main_logger().getChild('main')
        except ImportError:
            # stdout is reserved for the JSON results, so this goes to stderr
            sys.stderr.write("No logging capability available.\n")

    def expand_collection(self, coll):
        """
        Return names of tests in a collection, as a set.

        The special collection value "*" expands to all tests this class
        implements, i.e., all attributes whose name starts with 'bist_'.
        Raises KeyError if coll is not a known collection.
        """
        tests = self.collections[coll]
        if tests == "*":
            prefix = 'bist_'
            # Only strip the leading prefix; the test name itself may contain
            # the string 'bist_' again.
            return {
                name[len(prefix):]
                for name in dir(self)
                if name.startswith(prefix)
            }
        return set(tests)

    def _execute_test(self, testname):
        """
        Run a single test and return its (status, data) tuple.

        The data dictionary always carries the keys 'status' and 'error_msg',
        no matter if the test passed, failed, crashed, or does not exist. If
        the test raises an exception, the test counts as failed and the
        exception text becomes the error message; in debug mode, the exception
        is re-raised instead.
        """
        testmethod_name = "bist_{0}".format(testname)
        sys.stderr.write(
            "Executing test method: {0}\n\n".format(testmethod_name)
        )
        # Look up the method first, so that an AttributeError raised from
        # within a test is not mistaken for a missing test.
        testmethod = getattr(self, testmethod_name, None)
        if not callable(testmethod):
            error_msg = "Test not defined: {}".format(testname)
            sys.stderr.write(error_msg + "\n")
            return False, {'status': False, 'error_msg': error_msg}
        try:
            status, data = testmethod()
            data['status'] = status
            data['error_msg'] = data.get('error_msg', '')
            return status, data
        except Exception as ex:
            sys.stderr.write(
                "Test {} failed to execute: {}\n".format(testname, str(ex))
            )
            if self.args.debug:
                raise
            return False, {'status': False, 'error_msg': str(ex)}

    def run(self):
        """
        Execute tests.

        All requested tests are executed in alphabetical order, then the
        results are posted as JSON (after being filtered for LabView
        compatibility, if requested).

        Returns True on Success, i.e., if every single test passed.
        """
        tests_successful = True
        result = {}
        # Sorting makes the order of execution reproducible between runs
        for test in sorted(self.tests_to_run):
            status, result_data = self._execute_test(test)
            tests_successful = bool(status) and tests_successful
            result[test] = result_data
        if self.args.lv_compat:
            result = filter_results_for_lv(result)
        post_results(result)
        return tests_successful

    def _run_sfp_bist(self, master, slave=None):
        """
        Common code for all SFP+ loopback tests.

        master and slave are the UIO labels of the Aurora register blocks.
        In dry-run mode, the hardware is not touched and default values are
        returned.
        """
        if self.args.dry_run:
            return True, get_sfp_bist_defaults()
        sfp_bist_results = run_aurora_bist(master=master, slave=slave)
        return aurora_results_to_status(sfp_bist_results)

    ##########################################################################
    # BISTS
    # All bist_* methods must return a tuple (status, data): a True/False
    # success value, and a dictionary with the details.
    ##########################################################################
    def bist_rtc(self):
        """
        BIST for RTC (real time clock)

        Return dictionary:
        - date: Returns the current UTC time, with seconds-accuracy, in ISO 8601
                format, as a string. As if running 'date -Iseconds -u'.
        - time: Same time, but in seconds since epoch.

        Return status:
        Always returns True, unless reading the time throws an exception.
        """
        assert 'rtc' in self.tests_to_run
        # time.time() is already seconds since epoch, independent of the
        # local time zone. (Converting a UTC time tuple with time.mktime()
        # would be off by the local UTC offset.)
        now = int(time.time())
        return True, {
            'time': float(now),
            'date': time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime(now)) \
                    + "+00:00",
        }

    def bist_ddr3(self):
        """
        BIST for PL DDR3 DRAM
        Description: Calls a test to examine the speed of the DDR3. To be
        precise, it fires up a UHD session, which runs a DDR3 BiST internally.
        If that works, it'll return estimated throughput that was gathered
        during the DDR3 BiST.

        External Equipment: None

        Return dictionary:
        - throughput: The estimated throughput
        - error_msg: Only on failure to obtain a throughput value

        NOTE(review): The value is the number printed by uhd_usrp_probe in
        front of "MB", multiplied by 1000. That is not bytes/s, and it does
        not match the dry-run value either. The scaling has been left alone
        because the pass threshold below depends on it -- please confirm the
        intended unit.

        Return status:
        True if the DDR3 bist passed
        """
        assert 'ddr3' in self.tests_to_run
        if self.args.dry_run:
            return True, {'throughput': 1250e6}
        result = {}
        # The command is fixed, so there's no need to go through a shell
        ddr3_bist_executor = ['uhd_usrp_probe', '--args', 'addr=127.0.0.1']
        try:
            output = subprocess.check_output(
                ddr3_bist_executor,
                stderr=subprocess.STDOUT,
            )
        except subprocess.CalledProcessError as ex:
            # Don't throw errors from uhd_usrp_probe
            output = ex.output
        except OSError as ex:
            # uhd_usrp_probe could not be started at all (e.g., it is not
            # installed). Report that the same way as a failed run.
            output = b''
            result['error_msg'] = \
                "Failed to run uhd_usrp_probe: {}".format(str(ex))
        # Stray non-UTF-8 bytes in the output should not abort the test
        output = output.decode("utf-8", "replace")
        mobj = re.search(r"Throughput: (?P<thrup>[0-9.]+)\s?MB", output)
        if mobj is not None:
            result['throughput'] = float(mobj.group('thrup')) * 1000
        else:
            result['throughput'] = 0
            result['error_msg'] = result.get('error_msg', '') + \
                                        "\n\nFailed match throughput regex!"
        return result.get('throughput', 0) > 1000e3, result

    def bist_gpsdo(self):
        """
        BIST for GPSDO
        Description: Returns GPS information
        External Equipment: None; Recommend attaching an antenna or providing
                           fake GPS information

        Test options:
        - gps_warmup_timeout: Seconds to wait for GPS-WARMUP to go low
        - gps_lockok_timeout: Seconds to wait for GPS-LOCKOK to go high

        Return dictionary: A TPV dictionary as returned by gpsd.
        See also: http://www.catb.org/gpsd/gpsd_json.html

        Check for mode 2 or 3 to see if it's locked.

        Return status: True if gpsd returned a TPV object. Raises a
        RuntimeError if the chip does not leave the warmup state, or if gpsd
        does not respond in time. A missing GPS lock is reported on stderr,
        but does not fail the test.
        """
        assert 'gpsdo' in self.tests_to_run
        if self.args.dry_run:
            return True, {
                "class": "TPV",
                "time": "2017-04-30T11:48:20.10Z",
                "ept": 0.005,
                "lat": 30.407899,
                "lon": -97.726634,
                "alt": 1327.689,
                "epx": 15.319,
                "epy": 17.054,
                "epv": 124.484,
                "track": 10.3797,
                "speed": 0.091,
                "climb": -0.085,
                "eps": 34.11,
                "mode": 3
            }
        from usrp_mpm.periph_manager import n3xx
        gpio_tca6424 = n3xx.TCA6424(self.mb_rev)
        # Turn on GPS, give some time to acclimatize
        gpio_tca6424.set("PWREN-GPS")
        time.sleep(5)
        gps_warmup_timeout = float(
            self.args.option.get('gps_warmup_timeout', GPS_WARMUP_TIMEOUT))
        gps_lockok_timeout = float(
            self.args.option.get('gps_lockok_timeout', GPS_LOCKOK_TIMEOUT))
        # Wait for WARMUP to go low
        sys.stderr.write(
            "Waiting for WARMUP to go low for up to {} seconds...\n".format(
                gps_warmup_timeout))
        if not poll_with_timeout(
                lambda: not gpio_tca6424.get('GPS-WARMUP'),
                gps_warmup_timeout*1000, 1000
            ):
            raise RuntimeError(
                "GPS-WARMUP did not go low within {} seconds!".format(
                    gps_warmup_timeout))
        sys.stderr.write("Chip is warmed up.\n")
        # Wait for LOCKOK. Data sheet says wait up to 15 minutes for GPS lock.
        sys.stderr.write(
            "Waiting for LOCKOK to go high for up to {} seconds...\n".format(
                gps_lockok_timeout))
        if not poll_with_timeout(
                lambda: gpio_tca6424.get('GPS-LOCKOK'),
                gps_lockok_timeout*1000,
                1000
            ):
            sys.stderr.write("No GPS-LOCKOK!\n")
        for pin_name in ('GPS-SURVEY', 'GPS-PHASELOCK', 'GPS-ALARM'):
            sys.stderr.write("{} status: {}\n".format(
                pin_name, gpio_tca6424.get(pin_name)
            ))
        # Now read back response from chip
        my_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            my_sock.connect(('localhost', 2947))
            sys.stderr.write("Connected to GPSDO socket.\n")
            query_cmd = b'?WATCH={"enable":true,"json":true}'
            my_sock.sendall(query_cmd)
            sys.stderr.write("Sent query: {}\n".format(query_cmd))
            sock_read_line(my_sock, timeout=10)
            sys.stderr.write("Received initial newline.\n")
            result = {}
            # gpsd reports various object classes, we wait for the first TPV
            while result.get('class', None) != 'TPV':
                json_result = sock_read_line(my_sock, timeout=60)
                sys.stderr.write(
                    "Received JSON response: {}\n\n".format(json_result)
                )
                result = json.loads(json_result)
            my_sock.sendall(b'?WATCH={"enable":false}')
        finally:
            # Also release the socket if gpsd does not respond as expected
            my_sock.close()
        # If we reach this line, we have a valid result and the chip responded.
        # However, it doesn't necessarily mean we had a GPS lock.
        return True, result

    def bist_tpm(self):
        """
        BIST for TPM (Trusted Platform Module)

        This reads the caps value for all detected TPM devices.

        Return dictionary:
        - tpm<N>_caps: TPM manufacturer and version info. Is a multi-line
                       string.

        Return status: True if exactly one TPM device is detected.
        """
        assert 'tpm' in self.tests_to_run
        if self.args.dry_run:
            return True, {
                'tpm0_caps': "Fake caps value\n\nVersion 0.0.0",
            }
        result = {}
        props_to_read = ('caps',)
        base_path = '/sys/class/tpm'
        for tpm_device in os.listdir(base_path):
            if not tpm_device.startswith('tpm'):
                continue
            for key in props_to_read:
                prop_path = os.path.join(base_path, tpm_device, key)
                # The context manager makes sure the file gets closed
                with open(prop_path, 'r') as prop_file:
                    result['{}_{}'.format(tpm_device, key)] = \
                        prop_file.read().strip()
        return len(result) == 1, result

    def bist_clock_int(self):
        """
        BIST for clock lock from internal (25 MHz) source.
        Description: Checks to see if the daughtercard can lock to an internal
        clock source.

        External Equipment: None
        Return dictionary:
        - <sensor-name>:
          - locked: Boolean lock status

        There can be multiple ref lock sensors; for a pass condition they all
        need to be asserted.

        Note: This test is not implemented yet. Outside of dry-run mode, it
        always passes and returns an empty dictionary.
        """
        assert 'clock_int' in self.tests_to_run
        if self.args.dry_run:
            return True, {'ref_locked': True}
        # FIXME implement
        sys.stderr.write("Test not implemented.\n")
        return True, {}

    def bist_clock_ext(self):
        """
        BIST for clock lock from external source. Note: This test requires a
        connected daughterboard with a 'ref lock' sensor available.

        Description: Checks to see if the daughtercard can lock to the external
        reference clock.

        External Equipment: 10 MHz reference Source connected to "ref in".

        Return dictionary:
        - <sensor-name>:
          - locked: Boolean lock status

        There can be multiple ref lock sensors; for a pass condition they all
        need to be asserted.

        Note: This test is not implemented yet. Outside of dry-run mode, it
        always passes and returns an empty dictionary.
        """
        assert 'clock_ext' in self.tests_to_run
        if self.args.dry_run:
            return True, {'ref_locked': True}
        # FIXME implement
        sys.stderr.write("Test not implemented.\n")
        return True, {}

    def bist_sfp0_loopback(self):
        """
        BIST for SFP+ ports:
        Description: Runs a BER test on SFP0 by itself (no second port is
        involved).

        External Equipment: Loopback module in SFP0 required.

        Return dictionary:
        - elapsed_time: Float value, test time in seconds
        - max_roundtrip_latency: Float value, max roundtrip latency in seconds
        - throughput: Approximate data throughput in bytes/s
        - max_ber: Estimated maximum BER, float value.
        - errors: Number of errors
        - bits: Number of bits that were transferred

        Return status: True if no errors were counted.
        """
        return self._run_sfp_bist(master='misc-auro-regs0')

    def bist_sfp1_loopback(self):
        """
        BIST for SFP+ ports:
        Description: Runs a BER test on SFP1 by itself (no second port is
        involved).

        External Equipment: Loopback module in SFP1 required.

        Return dictionary:
        - elapsed_time: Float value, test time in seconds
        - max_roundtrip_latency: Float value, max roundtrip latency in seconds
        - throughput: Approximate data throughput in bytes/s
        - max_ber: Estimated maximum BER, float value.
        - errors: Number of errors
        - bits: Number of bits that were transferred

        Return status: True if no errors were counted.
        """
        return self._run_sfp_bist(master='misc-auro-regs1')

    def bist_sfp_loopback(self):
        """
        BIST for SFP+ ports:
        Description: Uses one SFP+ port to test the other. Pipes data out
        through one SFP, back to the other.

        External Equipment: Loopback cable between the two SFP+ ports
        required.

        Return dictionary:
        - elapsed_time: Float value, test time in seconds
        - max_roundtrip_latency: Float value, max roundtrip latency in seconds
        - throughput: Approximate data throughput in bytes/s
        - max_ber: Estimated maximum BER, float value.
        - errors: Number of errors
        - bits: Number of bits that were transferred

        Return status: True if no errors were counted.
        """
        return self._run_sfp_bist(
            master='misc-auro-regs0',
            slave='misc-auro-regs1',
        )

    def bist_gpio(self):
        """
        BIST for GPIO
        Description: Writes and reads the values to the GPIO

        Needed Equipment: External loopback as follows
            GPIO
            0<->6
            1<->7
            2<->8
            3<->9
            4<->10
            5<->11

        Return dictionary:
        - write_patterns: A list of patterns that were written
        - read_patterns: A list of patterns that were read back

        On failure, both lists contain a single entry only: The pattern that
        failed, and the value that was read back instead.
        """
        assert 'gpio' in self.tests_to_run
        gpio_width = 12
        patterns = list(range(64))
        if self.args.dry_run:
            return True, {
                'write_patterns': list(patterns),
                'read_patterns': list(patterns),
            }
        from usrp_mpm.periph_manager import n3xx
        gpio_tca6424 = n3xx.TCA6424(self.mb_rev)
        gpio_tca6424.set("FPGA-GPIO-EN")
        mb_regs = n3xx.MboardRegsControl(n3xx.n3xx.mboard_regs_label, self.log)
        mb_regs.set_fp_gpio_master(0xFFF)
        # Allow some time for the front-panel GPIOs to become usable
        time.sleep(.5)
        def _run_gpio(ddr, patterns):
            " Run a GPIO test for a given set of patterns "
            gpio_ctrl = n3xx.FrontpanelGPIO(ddr)
            for pattern in patterns:
                gpio_set_all(gpio_ctrl, pattern, gpio_width, ddr)
                time.sleep(0.1)
                gpio_rb = gpio_ctrl.get_all()
                if pattern != gpio_rb:
                    return False, {'write_patterns': [pattern],
                                   'read_patterns': [gpio_rb]}
            return True, {'write_patterns': list(patterns),
                          'read_patterns': list(patterns)}
        # First drive the lower six pins, then the upper six pins. Stop at
        # the first direction that fails.
        for ddr in (0x03f, 0xfc0):
            status, data = _run_gpio(ddr, patterns)
            if not status:
                break
        return status, data

    def bist_temp(self):
        """
        BIST for temperature sensors
        Description: Reads the temperature sensors on the motherboards and
        returns their values in mC

        Return dictionary:
        - <thermal-zone-name>: temp in mC

        Return status: True if at least one temperature sensor was found.
        """
        assert 'temp' in self.tests_to_run
        if self.args.dry_run:
            return True, {'fpga-thermal-zone': 30000}
        import pyudev
        context = pyudev.Context()
        result = {
            device.attributes.get('type').decode('ascii'): \
                    int(device.attributes.get('temp').decode('ascii'))
            for device in context.list_devices(subsystem='thermal')
            if 'temp' in device.attributes.available_attributes \
                    and device.attributes.get('temp') is not None
        }
        if not result:
            result['error_msg'] = "No temperature sensors found!"
        return 'error_msg' not in result, result

    def bist_fan(self):
        """
        BIST for fans
        Description: Reads the current state of the fans (cooling devices) on
        the motherboard

        Return dictionary:
        - <fan-name>: Value of the 'cur_state' attribute of that cooling
                      device.
                      NOTE(review): This used to be documented as "fan speed
                      in RPM"; please confirm what cur_state represents on
                      this platform.

        External Equipment: None

        Return status: True if exactly two cooling devices were found.
        """
        assert 'fan' in self.tests_to_run
        if self.args.dry_run:
            return True, {'cooling_device0': 10000, 'cooling_device1': 10000}
        import pyudev
        context = pyudev.Context()
        result = {
            device.sys_name: int(device.attributes.get('cur_state'))
            for device in context.list_devices(subsystem='thermal')
            if 'cur_state' in device.attributes.available_attributes \
                    and device.attributes.get('cur_state') is not None
        }
        return len(result) == 2, result

def gpio_set_all(gpio_bank, value, gpio_size, ddr_mask):
    """
    Helper function to set all output pins of a GPIO bank.

    The value is converted into a string of binary digits, and these digits
    are then written one by one to those pins which are marked as outputs in
    ddr_mask. Pins are visited in ascending order. The pin with number N
    receives the digit at position (N modulo number-of-outputs) of the binary
    string, and the digit is handed to gpio_bank.set() as a one-character
    string ("0" or "1").

    Arguments:
        gpio_bank -- gpio bank object; needs to provide set(pin, value)
        value -- value to set onto gpio bank
        gpio_size -- size of the gpio bank (number of pins)
        ddr_mask -- data direction register bit mask. 0 is input; 1 is output.
    """
    num_outputs = bin(ddr_mask).count("1")
    # Both strings have the most significant bit first
    value_digits = format(value, '0' + str(num_outputs) + 'b')[-(gpio_size):]
    direction_digits = \
        format(ddr_mask, '0' + str(gpio_size) + 'b')[-(gpio_size):]
    # Reading the direction string back to front yields the pins in
    # ascending order, starting at pin 0
    for pin, direction in zip(range(gpio_size), reversed(direction_digits)):
        if direction != "1":
            continue
        gpio_bank.set(pin, value_digits[pin % num_outputs])

##############################################################################
# main
##############################################################################
def main():
    """
    Run the BIST tool with the arguments from the command line.

    Returns True if all tests passed.
    """
    return N3XXBIST().run()

if __name__ == '__main__':
    # Exit code 0 means all tests passed. sys.exit() is used rather than
    # exit(), because the latter is provided by the site module for
    # interactive sessions and is not guaranteed to be available.
    sys.exit(0 if main() else 1)