#! /usr/bin/env python
"""
This script performs the client portion of the file transfer protocol.
"""

import argparse
import logging
from multiprocessing.pool import ThreadPool
from struct import unpack
import signal
import sys

import symphony_ftp_packets as packets
sys.path.append('../')
import ll_ifc

# Module-level logger; main() raises its level to DEBUG when --verbose is given.
LOG = logging.getLogger(__name__)
# Port used for all FTP traffic in this script: get_packet() ignores messages
# that arrive on any other port and send_packet() transmits on this one.
PORT = 128
# Extend the allowed port range so we can send messages at port 128
# (range(0, 129) covers ports 0 through 128 inclusive).
ll_ifc.ALLOWED_PORT_RANGE = range(0, 129)


def create_filename(file_id, module_address):
    """Build the local file name used to store a transferred file.

    The name has the form ``file_<module_address>_<file_id in lowercase hex>``.
    Note that the module address comes first in the name even though it is the
    second parameter.
    """
    address_part = format(module_address, '')
    hex_id_part = format(file_id, 'x')
    return '_'.join(('file', address_part, hex_id_part))


# Byte stride between consecutive segments in the output file: segment N is
# written at offset N * _SEGMENT_STRIDE.
_SEGMENT_STRIDE = 107
# Upper bound on how many missing segment numbers are listed in a single ACK.
_MAX_SEGMENTS_REQUESTED = 64
# Size in bytes of the little-endian CRC word stored at the start of the file.
_CRC_SIZE = 4


def _open_transfer_file(dev, pkt):
    """Create (or truncate) the output file for the transfer described by pkt.

    Returns a ``(filename, file_object)`` tuple. The file is opened for binary
    writing and the caller is responsible for closing it.
    """
    filename = create_filename(pkt.file_id, dev.get_unique_id())
    LOG.info("Creating new file %s", filename)
    return filename, open(filename, "wb")


def _ack_init_transfer(dev, pkt):
    """Send the ACK for the start of a transfer and re-select 'always' downlink mode."""
    send_packet(dev, packets.ACKInitTransfer(ack_type=0x00, file_id=pkt.file_id,
                                             file_version=pkt.file_version))
    dev.set_downlink_mode('always')


def _verify_and_strip_header(filename):
    """Validate the CRC stored in filename and rewrite the file without its header.

    The first 4 bytes of the file hold a little-endian CRC that is compared
    against ``packets.crc`` of everything after those 4 bytes.

    Returns True when the CRC matched and the file was rewritten, False when
    the file is too short to hold a CRC or the CRC does not match (the file is
    left untouched in that case).
    """
    with open(filename, 'rb') as stored:
        data = stored.read()

    # unpack() raises struct.error on a short buffer; report it like any other
    # corrupt file instead of crashing the transfer loop.
    if len(data) < _CRC_SIZE:
        LOG.error("File %s is too short to contain a CRC", filename)
        return False

    calculated_crc = packets.crc(data[_CRC_SIZE:])
    (sent_crc, ) = unpack('<L', data[:_CRC_SIZE])
    if calculated_crc != sent_crc:
        LOG.error("File CRC mismatch %s %s", calculated_crc, sent_crc)
        return False

    LOG.info("File CRC match. Re-writing file")
    # The first 16 bytes (four 4-byte words) are dropped from the final file.
    # NOTE(review): only the first word (the CRC) is interpreted here; the
    # meaning of the other three words is not visible in this file.
    with open(filename, 'wb') as stripped:
        stripped.write(data[4 * 4:])
    return True


def file_transfer_protocol(dev):
    """Receive one file from the module behind dev and store it on disk.

    Runs a three-state machine driven by the packets returned by get_packet():

    * ``Init``    - wait for an Init (or Segment) packet, open the output file
      and acknowledge the start of the transfer.
    * ``Segment`` - write each received segment at its offset in the file and,
      on the last segment or once nothing is missing, acknowledge with the
      list of segments still needed (at most 64 per ACK).
    * ``Apply``   - on an Apply packet, verify the file CRC, strip the header,
      acknowledge and return. A CRC failure restarts from ``Init``.

    Args:
        dev: module driver object; must provide set_downlink_mode(),
            get_irq_flags(), clear_irq_flags(), get_unique_id() and the
            methods used by get_packet() / send_packet().

    Returns:
        None, once a file has been received, verified and acknowledged.
    """
    LOG.info("Performing file transfer with device %s", dev)
    dev.set_downlink_mode('always')
    LOG.info(dev.get_irq_flags())
    dev.clear_irq_flags()
    LOG.info(dev.get_irq_flags())

    state = 'Init'
    filename = None
    fd = None
    # Segment numbers written so far. Only the numbers are tracked, so the
    # received packets (and their payloads) are not retained in memory.
    received = set()
    try:
        while True:
            pkt = get_packet(dev)

            LOG.info(" ")
            LOG.info("STATE: %s", state)
            if state == 'Init':
                # NOTE(review): a Segment packet received here starts a new
                # file and is ACKed like an Init, but its payload is not
                # written. Presumably the sender restarts; confirm.
                if isinstance(pkt, (packets.InitTransfer, packets.SegmentTransfer)):
                    filename, fd = _open_transfer_file(dev, pkt)
                    received = set()
                    _ack_init_transfer(dev, pkt)
                    state = 'Segment'

            elif state == 'Segment':
                if isinstance(pkt, packets.InitTransfer):
                    # A new Init mid-transfer discards the current file.
                    fd.close()
                    filename, fd = _open_transfer_file(dev, pkt)
                    received = set()
                    _ack_init_transfer(dev, pkt)
                elif isinstance(pkt, packets.CancelTransfer):
                    LOG.info("Cancel file %s", filename)
                    fd.close()
                    state = 'Init'
                elif isinstance(pkt, packets.ApplyTransfer):
                    LOG.warning('Ignoring Apply packet')
                elif isinstance(pkt, packets.SegmentTransfer):
                    fd.seek(pkt.segment_num * _SEGMENT_STRIDE)
                    fd.write(pkt.segment_payload)
                    received.add(pkt.segment_num)

                    # Sorted so the lowest missing segments are requested
                    # first and the request is deterministic.
                    missing = sorted(set(range(pkt.segments_total)) - received)
                    is_last_segment = pkt.segment_num == pkt.segments_total - 1
                    if is_last_segment or not missing:
                        if not missing:
                            ack_type = 0x00
                            fd.close()
                            state = 'Apply'
                        else:
                            ack_type = 0xFD

                        send_packet(dev, packets.ACKSegmentTransfer(
                            ack_type=ack_type, file_id=pkt.file_id,
                            file_version=pkt.file_version,
                            segments_needed=missing[:_MAX_SEGMENTS_REQUESTED]))

            elif state == 'Apply':
                if isinstance(pkt, packets.InitTransfer):
                    LOG.warning('Received Init in Apply state')
                    state = 'Init'
                elif isinstance(pkt, packets.CancelTransfer):
                    LOG.warning("Cancel file %s", filename)
                    fd.close()
                    state = 'Init'
                elif isinstance(pkt, packets.ApplyTransfer):
                    LOG.info("apply file %s", filename)
                    fd.close()
                    if not _verify_and_strip_header(filename):
                        state = 'Init'
                        continue

                    send_packet(dev, packets.ACKApplyTransfer(
                        ack_type=0x00, file_id=pkt.file_id,
                        file_version=pkt.file_version))
                    return
                elif isinstance(pkt, packets.SegmentTransfer):
                    LOG.warning("Received Segment in Apply state.  Ignoring segment.")
                else:
                    LOG.error('Unhandled packet type')
                    state = 'Init'
    finally:
        # Release the output file even when a device or parsing call raises.
        # Closing an already-closed file is a no-op.
        if fd is not None:
            fd.close()


def get_packet(dev):
    """Return the next valid FTP packet received from the module.

    Loops until a message arrives on PORT whose payload parses successfully.
    Messages on other ports and payloads that raise packets.ParseError are
    logged as warnings and skipped.
    """
    while True:
        message = dev.wait_for_received_message()

        if message.port == PORT:
            try:
                parsed = packets.symftp_parser(message.payload)
            except packets.ParseError as parse_error:
                LOG.warning("Error parsing packet: %s", parse_error)
            else:
                LOG.debug("Received packet %r", parsed)
                return parsed
        else:
            LOG.warning("Ignoring packet on port %s", message.port)


def send_packet(dev, pkt):
    """Serialize pkt with packets.build_packet() and send it via dev on PORT."""
    LOG.debug("Sending packet %r", pkt)
    raw_message = packets.build_packet(pkt)
    dev.send_message_checked(raw_message, PORT)


def main():
    """Parse the command line and run the file transfer on one or all modules.

    With ``--device`` the transfer runs on that single device. Otherwise every
    module returned by ll_ifc.get_all_modules() is served concurrently, one
    ThreadPool worker per module.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument('--device', help='Path of the device to program. '
                     'Default is to find CP210x devices and receive the '
                     'file transfer for all of them.')
    cli.add_argument('--verbose', '-v', action='store_true')
    options = cli.parse_args()

    if options.verbose:
        LOG.setLevel(logging.DEBUG)

    # Single-device mode: run the protocol directly in this thread.
    if options.device:
        with ll_ifc.ModuleDriver(options.device) as dev:
            file_transfer_protocol(dev)
        return

    with ll_ifc.get_all_modules() as devs:
        if len(devs) == 0:
            LOG.info("No modules found")
            return
        LOG.info(devs)

        # SIGINT is ignored while the pool is created and the previous
        # handler is restored immediately afterwards.
        previous_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
        workers = ThreadPool(len(devs))
        signal.signal(signal.SIGINT, previous_handler)

        try:
            pending = workers.map_async(file_transfer_protocol, devs)
            pending.get(60000)  # Without the timeout this blocking call ignores all signals
        except KeyboardInterrupt:
            LOG.info("Caught keyboard interrupt")
            sys.exit()

        # Only reached when every transfer finished without interruption.
        workers.close()
        workers.join()


if __name__ == '__main__':
    # Configure root logging at INFO before main() runs; the --verbose flag
    # handled in main() only raises this module's logger (LOG) to DEBUG.
    logging.basicConfig(level=logging.INFO)
    main()
