"""
This module defines the packets (both uplink and downlink) that are used in the
Symphony FTP protocol.
"""

import struct
from zlib import crc32

from scapy.packet import Packet, Raw, bind_layers
from scapy.fields import (ByteEnumField, ConditionalField, LEIntField, LEShortField, StrField,
                          FieldListField)

# SYMFTP packet types: maps the one-byte `pkt_type` wire value to its symbolic name.
# Used as the enum table for the `pkt_type` field of SymftpPacket; the same numeric
# codes are the ones passed as `pkt_type=` in this module's bind_layers() calls.
PACKET_TYPES = {
    0x00: "TRANSFER_INIT",
    0x01: "TRANSFER_CANCEL",
    0x02: "TRANSFER_APPLY",
    0x03: "TRANSFER_SEGMENT",
    0x04: "ACK_INIT",
    0x05: "ACK_SEGMENT",
    0x06: "ACK_APPLY",
}

# SYMFTP ack types: maps the one-byte `ack_type` wire value to its symbolic name.
# AckTypePacket.post_dissect rejects any dissected `ack_type` that is not a key here.
ACK_TYPES = {
    0x00: "ACK",
    0xFF: "NACK",
    0xFD: "NACK_SEGMENT",  # the only ack type for which ACKSegmentTransfer carries `segments_needed`
}


class SymftpPacket(Packet):
    """
    Header layer that is common to all SYMFTP packets.

    It's implemented as a separate "layer" in scapy terms: a one-byte packet
    type followed by a little-endian 32-bit CRC. The type-specific packet is
    carried as this layer's payload.
    """
    fields_desc = [
        ByteEnumField("pkt_type", None, PACKET_TYPES),
        LEIntField("crc", None),
    ]

    def post_build(self, pkt, pay):
        """
        Sets the CRC of the packet, if it's not already set.

        This is a scapy-specific hook, and it's called in the 'build' method.

        The CRC is computed over everything that follows the 5-byte header
        (1 byte of packet type + 4 bytes of CRC), i.e. the type byte and the
        CRC field itself are not covered.

        :param pkt: the built bytes of this layer.
        :param pay: the built bytes of the payload.
        :return: the bytes of this layer followed by the payload bytes.
        """
        if self.crc is None:
            _crc = crc(pkt[5:] + pay)
            # Splice the little-endian CRC into bytes 1..4, right after the type byte.
            pkt = pkt[:1] + struct.pack('<L', _crc) + pkt[5:]

        return pkt + pay

    def validate_crc(self):
        """
        Validates the CRC.

        The packet is rebuilt with its CRC field cleared, so that `post_build`
        recalculates it, and the result is compared with the CRC the packet
        currently carries. The packet's own `crc` field is always restored
        before returning or raising.

        :raises CrcError: if the CRC is not what it should be. The exception
            arguments are (CRC carried by the packet, recalculated CRC).
        """
        # Save off the CRC, then set it to None so that the CRC will be calculated.
        original_crc = self.crc
        self.crc = None

        try:
            # Build the packet then parse, which will have the calculated CRC
            new_pkt = self.__class__(self.build())
            if original_crc != new_pkt.crc:
                # Report the saved CRC: self.crc is None at this point.
                raise CrcError(original_crc, new_pkt.crc)
        finally:
            # Put back the original CRC
            self.crc = original_crc


class InitTransfer(Packet):
    """Transfer-init packet.

    Tells the end-nodes that a file transfer is about to begin; the
    applicable end-nodes should reply with an ACK or NACK.

    All three fields are little-endian 32-bit integers defaulting to zero.
    """
    name = "Initialize Transfer"
    fields_desc = [
        LEIntField("file_id", 0),
        LEIntField("file_version", 0),
        LEIntField("file_size", 0),
    ]


class CancelTransfer(Packet):
    """Transfer-cancel packet.

    Indicates that an end-node or the server is canceling the file transfer.

    All three fields are little-endian 32-bit integers defaulting to zero.
    """
    name = "Cancel Transfer"
    fields_desc = [
        LEIntField("file_id", 0),
        LEIntField("file_version", 0),
        LEIntField("file_size", 0),
    ]


class ApplyTransfer(Packet):
    """Transfer-apply packet.

    Tells an end-node to apply a file that has been transferred.

    All three fields are little-endian 32-bit integers defaulting to zero.
    """
    name = "Apply Transfer"
    fields_desc = [
        LEIntField("file_id", 0),
        LEIntField("file_version", 0),
        LEIntField("file_size", 0),
    ]


class SegmentTransfer(Packet):
    """Transfer-segment packet.

    Carries one segment of the file from the server to the client.

    The file identification fields are little-endian 32-bit integers, the
    segment counters are little-endian 16-bit integers, and the segment data
    follows as a string field.
    """
    name = "Segment Transfer"
    fields_desc = [
        LEIntField("file_id", 0),
        LEIntField("file_version", 0),
        LEIntField("file_size", 0),
        LEShortField("segment_num", 0),
        LEShortField("segments_total", 0),
        StrField("segment_payload", ""),
    ]


class AckTypePacket(Packet):
    """
    A class that provides a validation step for packets that have an ack_type that needs
    to be an element in ACK_TYPES.

    Subclasses are expected to declare an `ack_type` field; this base class
    declares no fields of its own.
    """

    def post_dissect(self, string):
        """
        Runs after dissection. Validates that the ack_type is a valid one.

        :param string: the value handed to this hook by scapy; it is returned unchanged.
        :return: `string`, untouched.
        :raises ParseError: if the dissected `ack_type` is not a key of ACK_TYPES.
        """
        if self.ack_type not in ACK_TYPES:
            # Exceptions do not interpolate logging-style arguments, so the
            # message has to be formatted explicitly.
            raise ParseError("Unknown ack type: %s" % (self.ack_type,))

        return string


class ACKInitTransfer(AckTypePacket):
    """ACK/NACK reply to a transfer-init packet.

    Acknowledges (or refuses) the initialization of a transfer. Holds a
    one-byte ack type followed by two little-endian 32-bit integers that
    identify the file and its version.
    """
    name = "ACK/NACK Transfer Initialization"
    fields_desc = [
        ByteEnumField("ack_type", 0, ACK_TYPES),
        LEIntField("file_id", 0),
        LEIntField("file_version", 0),
    ]


class ACKSegmentTransfer(AckTypePacket):
    """ACK/NACK reply to a segment.

    Holds a one-byte ack type followed by two little-endian 32-bit integers
    that identify the file and its version. When the ack type is 0xFD
    (NACK_SEGMENT) the packet additionally carries `segments_needed`, a list
    of little-endian 16-bit segment numbers; for any other ack type that
    field is absent.
    """
    name = "ACK/NACK Segment Transfer"
    fields_desc = [
        ByteEnumField("ack_type", 0, ACK_TYPES),
        LEIntField("file_id", 0),
        LEIntField("file_version", 0),
        # The segment list only exists on a NACK_SEGMENT (0xFD) reply.
        ConditionalField(
            FieldListField("segments_needed", [], LEShortField("segment", 0)),
            lambda ack: ack.ack_type == 0xFD,
        ),
    ]


class ACKApplyTransfer(AckTypePacket):
    """ACK/NACK reply to a transfer-apply packet.

    Holds a one-byte ack type followed by two little-endian 32-bit integers
    that identify the file and its version.
    """
    name = "ACK/NACK Transfer Apply"
    fields_desc = [
        ByteEnumField("ack_type", 0, ACK_TYPES),
        LEIntField("file_id", 0),
        LEIntField("file_version", 0),
    ]


# Bind every pkt_type code to the layer that follows the SymftpPacket header.
# The bindings are registered in ascending pkt_type order.
for _pkt_type, _layer in (
        (0x00, InitTransfer),
        (0x01, CancelTransfer),
        (0x02, ApplyTransfer),
        (0x03, SegmentTransfer),
        (0x04, ACKInitTransfer),
        (0x05, ACKSegmentTransfer),
        (0x06, ACKApplyTransfer),
):
    bind_layers(SymftpPacket, _layer, pkt_type=_pkt_type)
# Do not leave the loop variables behind in the module namespace.
del _pkt_type, _layer


def symftp_parser(buf):
    """
    Create a packet from the buffer provided.

    The buffer is dissected as a SymftpPacket header plus payload, the CRC is
    validated, and the payload layer (not the header) is handed back.

    :param buf: the received data; it is converted with `bytes()` before dissection.
    :return: the payload of the dissected SymftpPacket.
    :raises ParseError: if the payload was dissected as a `Raw` layer.
    :raises CrcError: if `validate_crc` finds that the CRC does not match.
    """
    raw = bytes(buf)
    header = SymftpPacket()
    header.dissect(raw)

    # A Raw payload means no SYMFTP layer was matched for this packet.
    if isinstance(header.payload, Raw):
        raise ParseError("Unknown packet type")

    header.validate_crc()

    return header.payload


def build_packet(pkt):
    """
    Wrap `pkt` in a SymftpPacket header (packet type and CRC) and serialize it.

    :param pkt: the packet to place after the SymftpPacket header.
    :return: the bytes to be sent over the air.
    """
    framed = SymftpPacket() / pkt
    return framed.build()


class ParseError(Exception):
    """Root of the exceptions raised while parsing a packet."""


class CrcError(ParseError):
    """Raised when a parsed packet carries a CRC that fails validation."""


def crc(data):
    """
    Compute the CRC used both for each packet and for the whole file.

    :param data: the bytes to checksum.
    :return: the CRC-32 of `data` as an unsigned 32-bit integer.
    """
    checksum = crc32(data)
    # Python 2's crc32 may return a negative number; masking with 32 one-bits
    # forces the result to be unsigned.
    return checksum & 0xFFFFFFFF
