#! /usr/bin/env python
"""
This script performs the server portion of the file transfer protocol.
"""

from __future__ import print_function
import argparse
from collections import namedtuple, defaultdict
from datetime import datetime, timedelta
import logging
try:
    from queue import Queue, Empty
except ImportError:
    from Queue import Queue, Empty
import random
from struct import pack
from threading import Thread
from time import sleep, time

import conductor
from conductor import ConductorAccount, AppToken
import dateutil.tz

import symphony_ftp_packets as packets

# Module-level logger; main() sets its level from the --verbose flag.
LOG = logging.getLogger(__name__)
# Record and timestamp formats handed to logging.basicConfig() when run as a script.
LOG_FORMAT = '%(asctime)s [%(levelname)s] %(name)s: %(message)s'
LOG_DATEFMT = '%H:%M:%S'
# segment_file() carries (128 - MAX_SEGMENT_PAYLOAD) bytes of file data in each segment.
# NOTE(review): the name suggests a payload size, but it is used as the per-segment
# overhead subtracted from 128 -- confirm the intended meaning.
MAX_SEGMENT_PAYLOAD = 21
# Number of times the init message is sent before the transfer gives up.
INIT_MESSAGE_RETRIES = 3
# Number of full re-sends of the file allowed after a transfer times out.
TRANSFER_RETRIES = 3
# Number of times the apply message is sent before the transfer gives up.
APPLY_MESSAGE_RETRIES = 3

# Values describing a segmented file, as computed by segment_file().
FileInfo = namedtuple('FileInfo', ['crc', 'file_id', 'file_version', 'file_size'])

# Port used for every message sent by this script; received messages on any other
# port are ignored.
PORT = 128
# Extend the allowed port range so we can send messages at port 128
conductor.ALLOWED_PORT_RANGE = range(0, 129)


def subscription_thread(subject, queue):
    """
    Thread body that keeps a subscription to 'subject' open forever.

    Every message received on PORT is pushed onto 'queue'; messages on any other
    port are dropped.  Whenever the subscription reports that it has terminated
    (locally or by the server) a new one is opened.  This function never returns.
    """

    def sub_callback(message):
        """ Forwards a message on the file transfer port to the thread-safe queue. """
        if message.port != PORT:
            return
        queue.put(message)

    while True:
        LOG.debug("Starting subscription for %s", subject)
        with subject.subscribe(sub_callback) as sub:
            # Poll until either side reports the subscription closed, then fall
            # out of the 'with' block so the outer loop can re-subscribe.
            while not (sub.terminated or sub.server_terminated):
                sleep(0.1)


def segment_file(file_data, file_id=None, file_version=None):
    """
    Prepends the file header to 'file_data' and splits the result into segments.

    A random 32-bit value is used for 'file_id' and/or 'file_version' when they
    are not given.  The header is the CRC followed by the file size, file ID and
    file version (each packed as a little-endian 32-bit value); the CRC covers
    the rest of the header plus the file data.

    Returns a tuple of (FileInfo, list of packets.SegmentTransfer).
    """
    size = len(file_data)
    file_id = random.getrandbits(32) if file_id is None else file_id
    file_version = random.getrandbits(32) if file_version is None else file_version

    # The CRC is computed over everything that follows it, so build that part first.
    crc_covered = pack('<LLL', size, file_id, file_version) + file_data
    crc_value = packets.crc(crc_covered)
    framed = pack('<L', crc_value) + crc_covered

    info = FileInfo(crc=crc_value, file_id=file_id, file_version=file_version, file_size=size)

    # Slice the framed file into fixed-size payloads (the last one may be shorter).
    step = 128 - MAX_SEGMENT_PAYLOAD
    payloads = []
    for offset in range(0, len(framed), step):
        payloads.append(framed[offset:offset + step])

    total = len(payloads)
    segments = []
    for index, payload in enumerate(payloads):
        segments.append(packets.SegmentTransfer(file_size=size, file_id=file_id,
                                                file_version=file_version, segment_num=index,
                                                segments_total=total, segment_payload=payload))

    return info, segments


class TransferInfo(object):
    """
    Contains the context for an ongoing file transfer.

    The transfer is a small state machine.  Each state ('Init', 'Transfer',
    'Apply') is a method that takes the receive queue and returns the name of
    the next state; 'Done' ends the transfer.
    """

    def __init__(self, subject, file_data, file_id=None, file_version=None, mailbox_subjects=None,
                 gateway=None, mailbox_period=None):
        """
        subject          -- Conductor object the segments are sent to.
        file_data        -- raw bytes of the file to transfer.
        file_id          -- file ID; chosen by segment_file() when None.
        file_version     -- file version; chosen by segment_file() when None.
        mailbox_subjects -- optional list of subjects that receive the init and apply
                            messages instead of 'subject'.
        gateway          -- optional gateway address passed along with every message.
        mailbox_period   -- time-to-live and response timeout (seconds) for the init
                            and apply messages; 10 is used when None.
        """
        self.subject = subject
        self.mailbox_subjects = mailbox_subjects
        self.gateway = gateway

        self.state_methods = {
            'Init': self._init,
            'Transfer': self._transfer,
            'Apply': self._apply_file,
        }

        self.state = 'Init'
        # Maps module -> bool.  A module is present as soon as it has been heard from;
        # the value becomes True once it has received / applied the file.
        self.modules_done = defaultdict(bool)
        self.modules_applied = defaultdict(bool)
        # Holds Conductor handles to segment messages
        self.full_transfers = []
        # Holds Conductor handles to all other messages
        self.command_ids = []
        message_timeout = mailbox_period if mailbox_period is not None else 10
        self.msg_ttl_s = message_timeout
        self.init_timeout_s = message_timeout
        self.apply_timeout_s = message_timeout
        self.transfer_retry_timeout_s = 60
        self.transfer_poll_period_s = 10

        self.file_info, self.segments = segment_file(file_data, file_id, file_version)
        LOG.info("Initialized file transfer for %s", self.file_info)

    def run(self, queue):
        """
        Runs through the state machine until completion.

        Exceptions raised by the state methods (CanceledError,
        ExhaustedRetriesError) propagate to the caller.
        """
        # Compare by value: an identity test against a str literal only works when
        # both strings happen to be interned.
        while self.state != 'Done':
            self.state = self.state_methods[self.state](queue)

    def all_modules_received(self):
        """
        Returns a bool indicating whether all modules have finished receiving the file.
        Returns False if there have been no modules.
        """
        return bool(self.modules_done) and all(self.modules_done.values())

    def all_modules_applied(self):
        """
        Returns a bool indicating whether all modules have finished applying the file.
        Returns False if there have been no modules.
        """
        return bool(self.modules_applied) and all(self.modules_applied.values())

    def _track_module(self, module):
        """
        Records that 'module' takes part in the transfer, without changing the
        received / applied status of a module that is already known.
        """
        self.modules_done.setdefault(module, False)
        self.modules_applied.setdefault(module, False)

    def _send_message(self, payload):
        """
        Send messages to nodes.  This handles the case where nodes might be in downlink
        mailbox mode, e.g. init and apply

        The payload goes to every mailbox subject, or to the main subject when there
        are none.  Returns the list of message handles returned by Conductor.
        """
        command_ids = []
        subjects = self.mailbox_subjects or [self.subject]
        for subject in subjects:
            ret = subject.send_message(payload, gateway_addr=self.gateway, acked=False,
                                       time_to_live_s=self.msg_ttl_s, port=PORT)
            LOG.debug('%s, %s', subject, ret)
            command_ids.append(ret)
        return command_ids

    def _send_segments(self, segments_needed=None):
        """
        Posts segments as downlink messages and stores the message handles.

        'segments_needed' is an iterable of segment indexes to send; every segment
        is sent when it is None or empty.
        """
        # The segments are queued up immediately, so the segments at the end of the file transfer
        # will be queued at the gateway.  The time-to-live for each segment is adjusted assuming
        # 2 packets per frame.
        ttl_s = 5.0
        segments = [self.segments[i] for i in segments_needed] if segments_needed else self.segments
        transfer = []
        for segment in segments:
            command_id = self.subject.send_message(
                packets.build_packet(segment), gateway_addr=self.gateway, acked=False,
                time_to_live_s=ttl_s, port=PORT)
            LOG.debug('segment %d/%d: %s', segment.segment_num, segment.segments_total, command_id)
            transfer.append(command_id)
            ttl_s += 1.0

        self.full_transfers.append(transfer)

    def _latest_transfer_complete(self, timeout=0):
        """
        Returns a boolean indicating whether all of the segments in the latest
        transfer have been sent over the air.

        Only the last message of the latest transfer is inspected.  The transfer is
        complete when that message expired on every route, or when more than
        'timeout' seconds have passed since its most recent 'Sent' event.  When no
        transfer has been started yet there is nothing in flight, so True is returned.
        """
        if not self.full_transfers:
            # Reachable when Apply is entered straight from Init and a NACK arrives.
            return True

        last_msg = self.full_transfers[-1][-1]
        events = last_msg.get_events()

        # NOTE(review): all() over an empty 'events' mapping is True, so a message with
        # no events at all is also reported as expired -- confirm this is intended.
        expired = all(
            any(status == 'Expired' for (status, _) in route) for route in events.values())
        if expired:
            LOG.warning("Last message of transfer expired and was not sent. "
                        "Considering transfer complete")
            return True

        sent_times = [date for route in events.values() for (status, date) in route
                      if status == 'Sent']
        if not sent_times:
            return False

        latest_sent_time = max(sent_times)
        time_since_sent = datetime.now(dateutil.tz.tzutc()) - latest_sent_time
        return time_since_sent > timedelta(seconds=timeout)

    def _send_init_transfer(self):
        """ Sends the init message and remembers the resulting message handles. """
        packet = packets.InitTransfer(file_size=self.file_info.file_size,
                                      file_id=self.file_info.file_id,
                                      file_version=self.file_info.file_version)
        command_ids = self._send_message(packets.build_packet(packet))
        self.command_ids.extend(command_ids)
        LOG.info('%s', command_ids)

    def _init(self, queue):
        """
        'Init' state: sends the init message and waits for a response that moves the
        state machine on, re-sending up to INIT_MESSAGE_RETRIES times.

        Returns the next state.  Raises ExhaustedRetriesError when no such response
        arrives, or CanceledError (from the packet handler) on a cancel or NACK.
        """
        for init_attempt in range(INIT_MESSAGE_RETRIES):
            LOG.debug("Attempt number %s of %s for init", init_attempt + 1, INIT_MESSAGE_RETRIES)
            self._send_init_transfer()
            deadline = time() + self.init_timeout_s
            while True:
                try:
                    module, packet = get_packet(queue, deadline)
                except Empty:
                    # Nothing useful arrived before the deadline; send init again.
                    break

                next_state = self._init_handle_packet(module, packet)
                if next_state is not None and next_state != self.state:
                    return next_state

        raise ExhaustedRetriesError("Too many retries in init")

    def _init_handle_packet(self, module, packet):
        """
        Handles one packet received in the 'Init' state.

        Returns the next state, or None when the packet does not change the state.
        Raises CanceledError when the module sends an InitTransfer (cancel) or
        answers the init with anything other than an ACK.
        """
        if isinstance(packet, packets.InitTransfer):
            LOG.warning('Cancel received')
            raise CanceledError('File transfer canceled by endpoint')
        elif isinstance(packet, packets.ACKInitTransfer):
            if packets.ACK_TYPES[packet.ack_type] == "ACK":
                LOG.info("Received init ack from %s", module)
                self._track_module(module)
                self._send_segments()
                return 'Transfer'
            else:
                raise CanceledError("Init NACK ({}) from {}".format(packet.ack_type, module))
        elif isinstance(packet, packets.ACKSegmentTransfer):
            ack_type = packets.ACK_TYPES[packet.ack_type]
            if ack_type == "ACK":
                LOG.info('File success for %s', module)
                self.modules_done[module] = True
                return 'Done'
            elif ack_type == "NACK_SEGMENT":
                self._track_module(module)
                LOG.info('segments_needed %s', packet.segments_needed)
                self._send_segments(segments_needed=packet.segments_needed)
                return 'Transfer'
            elif ack_type == "NACK":
                self._track_module(module)
                self._send_segments()
                return 'Transfer'
        elif isinstance(packet, packets.ACKApplyTransfer):
            LOG.warning('Apply received in init state for %s', module)
            self.modules_done[module] = True
            self.modules_applied[module] = True
            return 'Apply'

    def _transfer(self, queue):
        """
        'Transfer' state: services segment requests until every known module has
        acknowledged the file, then returns 'Apply'.

        The whole file is re-sent when the latest transfer has been out for longer
        than the retry timeout.  Raises ExhaustedRetriesError after TRANSFER_RETRIES
        such re-sends, or CanceledError when a module sends an InitTransfer (cancel).
        """
        num_transfers = 1
        while True:
            try:
                module, packet = get_packet(queue, deadline=time() + self.transfer_poll_period_s)
                if isinstance(packet, packets.InitTransfer):
                    LOG.warning('Cancel received')
                    raise CanceledError('File transfer canceled by endpoint')
                elif isinstance(packet, packets.ACKInitTransfer):
                    if packets.ACK_TYPES[packet.ack_type] == "ACK":
                        self._track_module(module)
                        LOG.info('Received ACK_INIT in the middle of transfer')
                        #TODO - might need to do something here for multi-cast transfers
                elif isinstance(packet, packets.ACKSegmentTransfer):
                    ack_type = packets.ACK_TYPES[packet.ack_type]
                    if ack_type == "ACK":
                        LOG.info('Transfer complete for %s', module)
                        self.modules_done[module] = True
                        self.modules_applied.setdefault(module, False)
                    elif ack_type == "NACK_SEGMENT":
                        LOG.info('segments_needed %s', packet.segments_needed)
                        self._send_segments(segments_needed=packet.segments_needed)
                        self._track_module(module)
                    elif ack_type == "NACK":
                        self._track_module(module)
                        # Only start over once the previous transfer has left the queue.
                        if self._latest_transfer_complete():
                            self._send_segments()
                elif isinstance(packet, packets.ACKApplyTransfer):
                    LOG.warning('Apply received in transfer state for %s', module)
                    self.modules_done[module] = True
                    self.modules_applied[module] = True
            except Empty:
                LOG.debug('timeout waiting on queue')

            if self.all_modules_received():
                return 'Apply'
            elif self._latest_transfer_complete(timeout=self.transfer_retry_timeout_s):
                if num_transfers > TRANSFER_RETRIES:
                    raise ExhaustedRetriesError('Never received the ACK_Transfer')
                LOG.info('Retrying full transfer due to timeout')
                self._send_segments()
                num_transfers = num_transfers + 1

    def _send_apply(self):
        """ Sends the apply message and remembers the resulting message handles. """
        packet = packets.ApplyTransfer(file_size=self.file_info.file_size,
                                       file_id=self.file_info.file_id,
                                       file_version=self.file_info.file_version)
        command_ids = self._send_message(packets.build_packet(packet))
        self.command_ids.extend(command_ids)
        LOG.info('%s', command_ids)

    def _apply_file(self, queue):
        """
        'Apply' state: sends the apply message and waits until every known module
        has acknowledged it, re-sending up to APPLY_MESSAGE_RETRIES times.

        Returns 'Done' on success, or 'Transfer' when a module asks for segments
        again.  Raises ExhaustedRetriesError when the acknowledgements never arrive,
        or CanceledError when a module sends an InitTransfer (cancel).
        """
        for apply_attempt in range(APPLY_MESSAGE_RETRIES):
            LOG.debug("Attempt number %s of %s for apply", apply_attempt + 1, APPLY_MESSAGE_RETRIES)
            self._send_apply()
            deadline = time() + self.apply_timeout_s
            while True:
                try:
                    module, packet = get_packet(queue, deadline)
                except Empty:
                    # Nothing useful arrived before the deadline; send apply again.
                    break

                next_state = self._apply_handle_packet(module, packet)
                if next_state is not None and next_state != self.state:
                    return next_state

                if self.all_modules_applied():
                    return 'Done'

        raise ExhaustedRetriesError("Too many retries in apply")

    def _apply_handle_packet(self, module, packet):
        """
        Handles one packet received in the 'Apply' state.

        Returns 'Transfer' when the module needs segments again, otherwise None.
        Raises CanceledError when the module sends an InitTransfer (cancel).
        """
        # Any module heard from in this state must apply the file before we finish.
        self.modules_applied.setdefault(module, False)
        if isinstance(packet, packets.InitTransfer):
            LOG.warning('Cancel received')
            raise CanceledError('File transfer canceled by endpoint')
        elif isinstance(packet, packets.ACKInitTransfer):
            if packets.ACK_TYPES[packet.ack_type] == "ACK":
                LOG.warning('Received ACK_INIT in the Apply State')
                #TODO - might need to do something here for multi-cast transfers
        elif isinstance(packet, packets.ACKSegmentTransfer):
            ack_type = packets.ACK_TYPES[packet.ack_type]
            if ack_type == "ACK":
                LOG.warning('Transfer complete from %s in Apply state', module)
            elif ack_type == "NACK_SEGMENT":
                LOG.warning('segments_needed in Apply state %s', packet.segments_needed)
                self._send_segments(segments_needed=packet.segments_needed)
                return 'Transfer'
            elif ack_type == "NACK":
                LOG.warning('All segments_needed in Apply state')
                if self._latest_transfer_complete():
                    self._send_segments()
                return 'Transfer'
        elif isinstance(packet, packets.ACKApplyTransfer):
            self.modules_applied[module] = True


def int_or_hex(num):
    """
    Converts a string to an int.  Strings starting with '0x' or '0X' are read as
    hexadecimal, everything else as decimal.  Raises ValueError for invalid input.
    """
    base = 16 if num.startswith(('0x', '0X')) else 10
    return int(num, base)


def mailbox_subjects_get(account, subject):
    """
    Returns the list of subjects that should each receive the init and apply messages.

    For an application token this is every module of the account whose
    registration token matches it; any other subject is returned as a
    one-element list.
    """
    if not isinstance(subject, AppToken):
        return [subject]

    matching = []
    for module in account.get_modules():
        if module._data['registrationToken'] == str(subject):
            matching.append(module)
    return matching


def get_packet(queue, deadline):
    """
    Waits on the queue for the next message that parses as a SYMFTP packet.

    Messages on the wrong port and messages that fail to parse are logged and
    skipped.  'deadline' is in the same units as `time.time`; `Empty` is raised
    once it passes without a usable message.

    Returns a tuple of (module_address, parsed_packet)
    """
    while True:
        message = queue_get_with_deadline(queue, deadline)

        if message.port == PORT:
            try:
                parsed = packets.symftp_parser(bytearray.fromhex(message.payload_hex))
            except packets.ParseError as err:
                LOG.warning("Unable to parse message from %s: %s", message.module, err)
            else:
                LOG.debug("Received packet %r from %s", parsed, message.module)
                return (message.module, parsed)
        else:
            LOG.warning("Ignoring packet with incorrect port: %s", message.port)


def queue_get_with_deadline(queue, deadline=None):
    """
    Blocking 'get' on a queue that is bounded by an absolute deadline rather
    than a relative timeout.

    The deadline should be in the same units as `time.time`.  With no deadline
    the call blocks until an item is available.  `Empty` is raised when the
    deadline has already passed or passes while waiting.
    """
    if deadline is None:
        return queue.get(timeout=None)

    remaining = deadline - time()
    if remaining < 0:
        raise Empty
    return queue.get(timeout=remaining)


class FTPError(Exception):
    """ Common base class for the errors raised by the FTP server. """


class CanceledError(FTPError):
    """ Raised when a module cancels or rejects the file transfer. """


class ExhaustedRetriesError(FTPError):
    """ Raised when a step has been retried the maximum number of times without success. """


def start_subscription_thread(subject, init_delay_s=16):
    """
    Starts the infinite subscription thread. Returns the queue that the thread
    is pushing to.

    The thread is a daemon, so it does not keep the process alive on exit.

    subject      -- Conductor object to subscribe to.
    init_delay_s -- seconds to block after starting the thread so the subscription
                    has time to come up before any message is sent.
    """
    queue = Queue()
    thread = Thread(target=subscription_thread, args=(subject, queue))
    thread.daemon = True
    thread.start()
    LOG.info("Sleeping for %s to let subscription initialize", init_delay_s)
    sleep(init_delay_s)  # Wait for subscription to be good and ready
    return queue


def main():
    """
    Command-line entry point: parses the arguments, logs in to Conductor, starts
    the subscription thread and runs one file transfer to completion.

    Exactly one of --module_id or --app_token must be given.  Errors raised by the
    transfer (FTPError subclasses) are not caught here.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('username', help='The username to authenticate to Conductor')
    parser.add_argument('filename', help='path of file to be transferred')
    parser.add_argument('--file_id', '-i', default=None, type=int_or_hex,
                        help='Specify the file ID (32-bit integer) to be used for '
                        'this file transfer')
    parser.add_argument('--file_version', '-V', default=None, type=int_or_hex,
                        help='Specify the file version (32-bit integer) to be '
                        'used for this file transfer')
    parser.add_argument('--password', '-p', help='The Conductor account password. '
                        'This is optional: if omitted, the script will prompt securely for '
                        'the password.')
    parser.add_argument('--gateway', '-g', default=None,
                        help='Specify the gateway through with the transfer will be routed.')
    parser.add_argument('--verbose', '-v', action='store_true')
    # A target is mandatory: without one there is no subject to send the file to.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--module_id', '-m', default=None,
                       help='The module ID used for unicast transfers.')
    group.add_argument('--app_token', '-a', default=None,
                       help='The application token used for multi-cast transfers.')
    # NOTE(review): the help text used to claim the default was None, but the default
    # has always been 10, so the 'is None' branch below cannot be reached from the
    # command line -- confirm which of the two was intended.
    parser.add_argument('--mailbox_period', default=10, type=float,
                        help='For applications that use downlink mailbox mode '
                        'in their idle state, provide the worse-case mailbox check period. '
                        'This determines the time-to-live for the INIT and APPLY messages. '
                        'The default is 10 seconds.')
    args = parser.parse_args()

    LOG.setLevel(logging.DEBUG if args.verbose else logging.INFO)

    account = ConductorAccount(args.username, args.password)

    # The argument group guarantees that exactly one of the two targets was given.
    if args.module_id is not None:
        subject = account.get_module(args.module_id)
    else:
        subject = account.get_application_token(args.app_token)

    if args.mailbox_period is None:
        mailbox_subjects = None
    else:
        mailbox_subjects = mailbox_subjects_get(account, subject)

    with open(args.filename, 'rb') as file_handle:
        file_data = file_handle.read()

    # The reported transfer time includes the subscription start-up delay.
    start = time()
    queue = start_subscription_thread(subject)

    transfer_info = TransferInfo(subject, file_data, args.file_id, args.file_version,
                                 mailbox_subjects, args.gateway, args.mailbox_period)
    transfer_info.run(queue)

    LOG.info('Transfer time %s', time() - start)
    print('Transfer successful to modules {}'.format(list(transfer_info.modules_applied.keys())))


if __name__ == '__main__':
    # Configure the root logger at WARNING; main() sets this module's own logger
    # level separately from the --verbose flag.
    logging.basicConfig(level=logging.WARNING, format=LOG_FORMAT, datefmt=LOG_DATEFMT)
    main()
