#!/usr/bin/env python3
# encoding: utf-8
"""
bgp

Created by Thomas Mangin
Copyright (c) 2013-2017 Exa Networks. All rights reserved.
License: 3-clause BSD. (See the COPYRIGHT file)
"""

import os
import pwd
import sys
import time
import errno
import socket
import threading
import signal
try:
    import asyncore
except:
    import py_asyncore as asyncore
import subprocess
from struct import unpack

# Signal name (SIGHUP, SIGUSR1, ...) -> signal number, used to validate and
# resolve the signal names used by the test files.  Names starting with
# 'SIG_' (SIG_DFL, SIG_IGN, SIG_BLOCK, SIG_SETMASK, ...) are handler or mask
# constants, not signals: their small integer values would make os.kill
# deliver some unrelated signal, so they are excluded.
SIGNAL = {name: getattr(signal, name) for name in dir(signal) if name.startswith('SIG') and not name.startswith('SIG_')}


def flushed(*output):
    """Print every argument, space separated, on one line and flush stdout at once.

    Flushing keeps the output ordered when stdout is a pipe read by the test runner.
    """
    print(*output, flush=True)


def bytestream(value):
    """Return *value* (an iterable of ints, e.g. bytes) as contiguous upper-case hex."""
    return ''.join(format(octet, '02X') for octet in value)


def dump(value):
    """Return *value* as upper-case hex with a space after every second byte.

    For example b'\\x00\\x01\\x02\\x03\\x04' becomes '0001 0203 04'.
    """
    octets = ['%02X' % octet for octet in value]
    pairs = (''.join(octets[start:start + 2]) for start in range(0, len(octets), 2))
    return ' '.join(pairs)


def cdr_to_length(cidr):
    """Return how many address bytes follow an IPv4 prefix length of *cidr* bits.

    0 (or less) needs no byte, 1-8 one byte, 9-16 two, 17-24 three and
    anything above 24 the full four bytes.
    """
    for threshold, octets in ((24, 4), (16, 3), (8, 2), (0, 1)):
        if cidr > threshold:
            return octets
    return 0


class BGPHandler(asyncore.dispatcher_with_send):
    """Server side of a single BGP test session.

    The handler answers the peer's OPEN with a copy of it (router ID changed),
    then decodes every following message and checks it against the expected
    messages given to ``setup``.  Expected messages are grouped by sequence
    key; every message of the lowest remaining key must arrive (in any order)
    before the next key becomes current.
    """

    # number of decoded messages seen so far (becomes an instance attribute on first +=)
    counter = 0

    # complete KEEPALIVE message: 16-byte all-ones marker, length 19 (0x0013), type 4
    keepalive = bytearray([0xFF,] * 16 + [0x0, 0x13, 0x4])

    # message type (header byte 18) -> name.  Keys are ints because indexing a
    # bytearray yields ints; one-byte ``bytes`` keys would never match.
    _name = {
        1: 'OPEN',
        2: 'UPDATE',
        3: 'NOTIFICATION',
        4: 'KEEPALIVE',
    }

    def signal(self, myself, signal_name='SIGUSR1'):
        """Send *signal_name* to the ExaBGP process running this test's configuration.

        The target is searched in ``/bin/ps x`` output: a python/pypy process
        whose second command word (when present) ends with ``/bgp.py`` and whose
        last word is a ``.conf`` path containing the basename (without
        extension) of ``sys.argv[1]``.  Exits with status 1 when the signal name
        is unknown or when zero or several processes match; a failing
        ``os.kill`` is only reported.  *myself* is not used.
        """
        signal_number = SIGNAL.get(signal_name, '')
        if not signal_number:
            self.announce('invalid signal name in configuration : %s' % signal_name)
            self.announce('options are: %s' % ','.join(SIGNAL.keys()))
            sys.exit(1)

        conf_name = sys.argv[1].split('/')[-1].split('.')[0]

        processes = []

        for line in os.popen("/bin/ps x"):
            low = line.strip().lower()
            if not low:
                continue
            if 'python' not in low and 'pypy' not in low:
                continue

            # assumes ps prints the PID first and the command from the 5th column on
            cmdline = line.strip().split()[4:]
            pid = line.strip().split()[0]

            if len(cmdline) > 1 and not cmdline[1].endswith('/bgp.py'):
                continue

            if conf_name not in cmdline[-1]:
                continue

            if not cmdline[-1].endswith('.conf'):
                continue

            processes.append(pid)

        if len(processes) == 0:
            self.announce('no running process found, this should not happend, quitting')
            sys.exit(1)

        if len(processes) > 1:
            self.announce('more than one process running, this should not happend, quitting')
            sys.exit(1)

        try:
            self.announce('sending signal %s to ExaBGP (pid %s)\n' % (signal_name, processes[0]))
            os.kill(int(processes[0]), signal_number)
        except Exception as exc:
            self.announce('\n     failed: %s' % str(exc))

    def kind(self, header):
        """Return the message type byte of a 19-byte BGP header."""
        return header[18]

    def isupdate(self, header):
        """Return True when *header* belongs to an UPDATE (type 2)."""
        return header[18] == 2

    def isnotification(self, header):
        """Return True when *header* belongs to a NOTIFICATION (type 3; 4 is KEEPALIVE)."""
        return header[18] == 3

    def name(self, header):
        """Return the message type name of *header*, or a placeholder for unknown types."""
        return self._name.get(header[18], 'SOME WEIRD RFC PACKET')

    def routes(self, header, body):
        """Yield one textual line per route found in an UPDATE *body*.

        Withdrawn and announced IPv4 prefixes become ``withdraw:a.b.c.d/len`` and
        ``announce:a.b.c.d/len``.  When there are neither, a single line is
        yielded: ``eor:1:1`` for a 4-byte body, ``eor:<x>:<y>`` built from the
        last two bytes of an 11-byte body (presumably the AFI low byte and the
        SAFI of an empty MP_UNREACH_NLRI — verify), and ``mp:`` otherwise, as
        such content is not decoded here.
        """
        len_w = unpack('!H', body[0:2])[0]
        withdrawn = bytearray([_ for _ in body[2 : 2 + len_w]])
        len_a = unpack('!H', body[2 + len_w : 2 + len_w + 2])[0]
        # the NLRI field is whatever follows the path attributes
        announced = bytearray([_ for _ in body[2 + len_w + 2 + len_a :]])

        if not withdrawn and not announced:
            if len(body) == 4:
                yield 'eor:1:1'
            elif len(body) == 11:
                yield 'eor:%d:%d' % (body[-2], body[-1])
            else:  # undecoded MP route
                yield 'mp:'
            return

        while withdrawn:
            cdr, withdrawn = withdrawn[0], withdrawn[1:]
            size = cdr_to_length(cdr)
            r = [0, 0, 0, 0]
            for index in range(size):
                r[index], withdrawn = withdrawn[0], withdrawn[1:]
            yield 'withdraw:%s' % '.'.join(str(_) for _ in r) + '/' + str(cdr)

        while announced:
            cdr, announced = announced[0], announced[1:]
            size = cdr_to_length(cdr)
            r = [0, 0, 0, 0]
            for index in range(size):
                r[index], announced = announced[0], announced[1:]
            yield 'announce:%s' % '.'.join(str(_) for _ in r) + '/' + str(cdr)

    def notification(self, header, body):
        """Yield ``notification:<code>,<subcode>`` for a NOTIFICATION *body*.

        A single string is yielded so the caller can match it like any other
        announcement (a tuple used to be yielded, which crashed on
        ``startswith``).  Missing code/subcode bytes are reported as 0.
        """
        code = body[0] if len(body) > 0 else 0
        subcode = body[1] if len(body) > 1 else 0
        yield 'notification:%d,%d' % (code, subcode)

    def announce(self, *args):
        """Print the peer address and *args* on one indented, flushed line."""
        flushed('    ', self.ip, self.port, ' '.join(str(_) for _ in args) if len(args) > 1 else args[0])

    def check_signal(self):
        """If the next expected message is ``signal:<NAME>``, consume it and send that signal."""
        if self.messages and self.messages[0].startswith('signal:'):
            name = self.messages.pop(0).split(':')[-1]
            self.signal(os.getppid(), name)

    def setup(self, ip, port, messages, options):
        """Configure the handler for one session and return it.

        *messages* are ``<sequence>:<announcement>`` strings.  A ``raw:``
        announcement is a hex dump whose ``:`` separators are removed; any raw
        message switches the whole handler to raw (hex) comparison.
        """
        self.ip = ip
        self.port = port
        self.options = options
        # the first message of the session is the peer's OPEN
        self.handle_read = self.handle_open
        self.sequence = {}
        self.raw = False
        for rule in messages:
            sequence, announcement = rule.split(':', 1)
            if announcement.startswith('raw:'):
                self.raw = True
                announcement = announcement[4:].replace(':', '')
            self.sequence.setdefault(sequence, []).append(announcement)
        self.update_sequence()
        return self

    def update_sequence(self):
        """Make the lowest remaining sequence key the current group of expected messages.

        Returns True when a group was loaded (always in sink/echo mode, where
        nothing is expected) and False when no sequence is left.  Leading
        ``signal:`` entries are processed immediately.
        """
        if self.options['sink'] or self.options['echo']:
            self.messages = []
            return True
        keys = sorted(list(self.sequence))
        if keys:
            key = keys[0]
            self.messages = self.sequence[key]
            self.step = key
            del self.sequence[key]

            self.check_signal()
            # we had a list with only one signal
            if not self.messages:
                return self.update_sequence()
            return True
        return False

    def read_message(self):
        """Read one complete BGP message from the socket.

        Returns ``(header, body)`` as bytearrays, or ``(None, None)`` when the
        TCP session closes before the message is complete.  Would-block errors
        are retried; other socket errors propagate.
        """
        header = b''
        while len(header) != 19:
            try:
                chunk = self.recv(19 - len(header))
            except socket.error as exc:
                if exc.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
                    continue
                raise
            if not chunk:
                # the TCP session is gone.
                return None, None
            header += chunk

        # the length field covers the whole message, header included
        length = unpack('!H', header[16:18])[0] - 19

        body = b''
        while len(body) != length:
            try:
                chunk = self.recv(length - len(body))
            except socket.error as exc:
                if exc.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
                    continue
                raise
            if not chunk:
                # closed in the middle of a message: without this check the
                # loop would spin forever on empty reads
                return None, None
            body += chunk

        return bytearray(header), bytearray(body)

    def handle_open(self):
        """Answer the peer's OPEN with a copy of it, then switch to ``handle_keepalive``."""
        # reply with a IBGP response with the same capability (just changing routerID)
        header, body = self.read_message()
        if header is None:
            self.announce('connection closed')
            self.close()
            return

        # body[8] is the last byte of the 4-byte BGP identifier (OPEN body offsets 5-8)
        routerid = bytearray([body[8] + 1 & 0xFF])
        o = header + body[:8] + routerid + body[9:]

        if self.options['send-unknown-capability']:
            # hack capability 66 into the message: o[17] is the low byte of the
            # message length, o[28] the optional parameters length
            content = b'loremipsum'
            cap66 = bytearray([66, len(content)]) + content
            param = bytearray([2, len(cap66)]) + cap66
            o = o[:17] + bytearray([o[17] + len(param)]) + o[18:28] + bytearray([o[28] + len(param)]) + o[29:] + param

        self.send(o)
        self.send(self.keepalive)

        if self.options['send-default-route']:
            # UPDATE with ORIGIN IGP, empty AS_PATH, NEXT_HOP 127.0.0.1 and
            # LOCAL_PREF 100.
            # NOTE(review): the NLRI 0x20 0.0.0.0 is 0.0.0.0/32, not the default
            # route 0.0.0.0/0 — confirm which one the tests expect.
            self.send(
                bytearray(
                    [0xFF,] * 16
                    + [0x00, 0x31]
                    + [0x02,]
                    + [0x00, 0x00]
                    + [0x00, 0x15]
                    + []
                    + [0x40, 0x01, 0x01, 0x00]
                    + []
                    + [0x40, 0x02, 0x00]
                    + []
                    + [0x40, 0x03, 0x04, 0x7F, 0x00, 0x00, 0x01]
                    + []
                    + [0x40, 0x05, 0x04, 0x00, 0x00, 0x00, 0x64]
                    + [0x20, 0x00, 0x00, 0x00, 0x00]
                )
            )
            self.announce('sending default-route\n')

        self.handle_read = self.handle_keepalive

    def handle_keepalive(self):
        """Handle every message after the OPEN.

        Sink mode logs each message and replies with a KEEPALIVE; echo mode
        sends the message back.  Otherwise each decoded line is matched against
        the current group of expected messages: an unexpected line prints what
        was expected and exits with status 1; once everything expected has been
        seen the process exits with status 0 (in single-shot mode, or when
        ``options['exit']`` is set).
        """
        header, body = self.read_message()

        if header is None:
            self.announce('connection closed')
            self.close()
            if self.options['send-notification']:
                self.announce('successful')
                sys.exit(0)
            return

        if self.raw:

            def parser(self, header, body):
                if body:
                    yield bytestream(header + body)

        else:
            parser = self._decoder.get(self.kind(header), None)

        if self.options['sink']:
            self.announce(
                'received %d: %s'
                % (
                    self.counter,
                    '%s:%s:%s:%s'
                    % (bytestream(header[:16]), bytestream(header[16:18]), bytestream(header[18:]), bytestream(body)),
                )
            )
            self.send(self.keepalive)
            return

        if self.options['echo']:
            self.announce(
                'received %d: %s'
                % (
                    self.counter,
                    '%s:%s:%s:%s'
                    % (bytestream(header[:16]), bytestream(header[16:18]), bytestream(header[18:]), bytestream(body)),
                )
            )
            self.send(header + body)
            self.announce(
                'sent     %d: %s'
                % (
                    self.counter,
                    '%s:%s:%s:%s'
                    % (bytestream(header[:16]), bytestream(header[16:18]), bytestream(header[18:]), bytestream(body)),
                )
            )
            return

        if parser:
            for announcement in parser(self, header, body):
                self.send(self.keepalive)
                if announcement.startswith('eor:'):  # skip EOR
                    self.announce('skipping eor', announcement)
                    continue

                if announcement.startswith('mp:'):  # skip unparsed MP
                    self.announce('skipping multiprotocol :', dump(body))
                    continue

                self.counter += 1

                if announcement in self.messages:
                    self.messages.remove(announcement)
                    if self.raw:
                        self.announce(
                            'received %d (%1s%s):' % (self.counter, self.options['letter'], self.step),
                            '%s:%s:%s:%s'
                            % (announcement[:32], announcement[32:36], announcement[36:38], announcement[38:]),
                        )
                    else:
                        self.announce(
                            'received %d (%1s%s):' % (self.counter, self.options['letter'], self.step), announcement
                        )
                    self.check_signal()
                else:
                    if self.raw:
                        self.announce(
                            'received %d (%1s%s):' % (self.counter, self.options['letter'], self.step),
                            '%s:%s:%s:%s'
                            % (
                                bytestream(header[:16]),
                                bytestream(header[16:18]),
                                bytestream(header[18:]),
                                bytestream(body),
                            ),
                        )
                    else:
                        self.announce('received %d     :' % self.counter, announcement)

                    if len(self.messages) > 1:
                        self.announce('expected one of the following :')
                        for message in self.messages:
                            if message.startswith('F' * 32):
                                self.announce(
                                    '                 %s:%s:%s:%s'
                                    % (message[:32], message[32:36], message[36:38], message[38:])
                                )
                            else:
                                self.announce('                 %s' % message)
                    elif self.messages:
                        message = self.messages[0].upper()
                        if message.startswith('F' * 32):
                            self.announce('expected       : %s:%s:%s:%s' % (message[:32], message[32:36], message[36:38], message[38:]))
                        else:
                            self.announce('expected       : %s' % message)
                    else:
                        # can happen when the thread is still running
                        self.announce('extra data')
                        sys.exit(1)

                    sys.exit(1)

                if not self.messages:
                    if self.options['single-shot']:
                        self.announce('successful (partial test)')
                        sys.exit(0)

                    if not self.update_sequence():
                        if self.options['exit']:
                            self.announce('successful')
                            sys.exit(0)
        else:
            self.send(self.keepalive)

        if self.options['send-notification']:
            # NOTIFICATION code 6 (Cease), subcode 0, with a text payload
            notification = b'closing session because we can'
            self.send(
                bytearray([0xFF,] * 16 + [0x00, 19 + 2 + len(notification)] + [0x03] + [0x06] + [0x00]) + notification
            )

    # message type -> decoder; plain functions, called as parser(self, header, body)
    _decoder = {
        2: routes,
        3: notification,
    }


class BGPServer(asyncore.dispatcher):
    """TCP listener that starts a BGPHandler for every accepted connection.

    Expected messages are grouped by a leading letter (``A`` when the line
    does not start with one); each accepted connection consumes the next
    group in alphabetical order.
    """

    def announce(self, *args):
        """Print *args*, space separated, on one indented and flushed line."""
        # The indent used to be applied only when several arguments were given:
        # the conditional expression bound looser than the string concatenation.
        text = ' '.join(str(_) for _ in args) if len(args) > 1 else str(args[0])
        flushed('    ' + text)

    def __init__(self, host, options):
        """Listen on *host* at ``options['port']`` and load the test definition.

        ``options['messages']`` holds the lines of the test file: recognised
        ``option:`` lines toggle session behaviour, every other line is an
        expected message, optionally prefixed by its connection letter.
        """
        asyncore.dispatcher.__init__(self)

        if ':' in host:
            self.create_socket(socket.AF_INET6, socket.SOCK_STREAM)
        else:
            self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.set_reuse_addr()
        self.bind((host, options['port']))
        self.listen(5)

        # connection letter -> list of expected messages
        self.messages = {}

        self.options = {
            'send-unknown-capability': False,  # add an unknown capability to the open message
            'send-default-route': False,  # send a default route to the peer
            'send-notification': False,  # send notification messages to the backend
            'signal-SIGUSR1': 0,  # send SIGUSR1 after X seconds
            'single-shot': False,  # we can not test signal on python 2.6
            'sink': False,  # just accept whatever is sent
            'echo': False,  # just accept whatever is sent
        }
        self.options.update(options)

        for message in options['messages']:
            stripped = message.strip()
            if stripped == 'option:open:send-unknown-capability':
                self.options['send-unknown-capability'] = True
                continue
            if stripped == 'option:update:send-default-route':
                self.options['send-default-route'] = True
                continue
            if stripped == 'option:notification:send-notification':
                self.options['send-notification'] = True
                continue
            if stripped.startswith('option:SIGUSR1:'):

                def notify(delay, myself):
                    time.sleep(delay)
                    self.signal(myself)
                    time.sleep(10)

                # Python 2.6 can not perform this test as it misses the function
                if 'check_output' in dir(subprocess):
                    # NOTE(review): this thread is created but never started, so
                    # no delayed signal is sent; BGPServer also has no ``signal``
                    # method, so ``notify`` would fail if it ever ran. Left as-is
                    # so the signals existing tests receive do not change.
                    threading.Thread(target=notify, args=(int(message.split(':')[-1]), os.getpid()))
                else:
                    self.options['single-shot'] = True
                continue

            if message[0].isalpha():
                index, content = message[:1].upper(), message[1:]
            else:
                index, content = 'A', message
            self.messages.setdefault(index, []).append(content)

    def handle_accept(self):
        """Accept a connection and hand it the next group of expected messages.

        Exits with status 1 when no group is left (outside sink and echo mode).
        """
        messages = None
        for number in range(ord('A'), ord('Z') + 1):
            letter = chr(number)
            if letter in self.messages:
                messages = self.messages.pop(letter)
                break

        if self.options['sink']:
            flushed('\nsink mode - send us whatever, we can take it ! :p\n')
            messages = []
        elif self.options['echo']:
            flushed('\necho mode - send us whatever, we can parrot it ! :p\n')
            messages = []
        elif not messages:
            self.announce('we used all the test data available, can not handle this new connection')
            sys.exit(1)
        else:
            flushed('using :\n   ', '\n    '.join(messages), '\n\nconversation:\n')

        # a session may only report overall success once every group is consumed
        self.options['exit'] = not self.messages
        self.options['letter'] = letter

        pair = self.accept()
        if pair is not None:
            sock, addr = pair
            # the handler registers itself with asyncore; no reference is needed here
            BGPHandler(sock).setup(*addr[:2], messages=messages, options=self.options)


def drop():
    """Drop root user/group privileges to the 'nobody' account.

    Nothing is done when neither the uid nor the gid is 0.  The group is
    changed before the user because, once the uid is no longer 0, the
    process may no longer change its group.  When no candidate account
    exists a message is printed and privileges are left unchanged.
    """
    uid = os.getuid()
    gid = os.getgid()

    if uid and gid:
        return

    for name in [
        'nobody',
    ]:
        try:
            user = pwd.getpwnam(name)
        except KeyError:
            continue
        nuid = int(user.pw_uid)
        # the group id lives in pw_gid (pw_uid was used here before)
        ngid = int(user.pw_gid)
        break
    else:
        # previously this fell through to an UnboundLocalError on nuid/ngid
        flushed('could not drop privileges: no unprivileged account found')
        return

    if not gid:
        os.setgid(ngid)
    if not uid:
        os.setuid(nuid)


def main():
    """Parse the command line, bind the test listeners and run the event loop.

    The port comes from the ``exabgp.tcp.port`` (or ``exabgp_tcp_port``)
    environment variable, default 179, and may be overridden with
    ``--port <port>``.  ``--sink`` / ``--echo`` select the catch-all modes;
    any other argument names the file of expected messages.  Exits with
    status 1 on usage errors and when the listeners cannot be set up.
    """
    port = os.environ.get('exabgp.tcp.port', os.environ.get('exabgp_tcp_port', '179'))

    # The previous check compared the *string* port with ints, which raises
    # TypeError on Python 3 as soon as the value is not all digits.
    if not (port.isdigit() and 0 < int(port) <= 65535) or len(sys.argv) <= 1:
        flushed('--sink   accept any BGP messages and reply with a keepalive')
        flushed('--echo   accept any BGP messages send it back to the emiter')
        flushed('--port <port>   port to bind to')
        flushed(
            'a list of expected route announcement/withdrawl in the format <number>:announce:<ipv4-route> <number>:withdraw:<ipv4-route> <number>:raw:<exabgp hex dump : separated>'
        )
        flushed('for example:', sys.argv[0], '1:announce:10.0.0.0/8 1:announce:192.0.2.0/24 2:withdraw:10.0.0.0/8 ')
        flushed('routes with the same <number> can arrive in any order')
        sys.exit(1)

    options = {'sink': False, 'echo': False, 'port': int(port), 'messages': []}

    for arg in sys.argv[1:]:
        if arg == '--sink':
            options['sink'] = True
            continue

        if arg == '--echo':
            options['echo'] = True
            continue

        if arg == '--port':
            args = sys.argv[1:] + [
                '',
            ]
            port = args[args.index('--port') + 1]
            if port.isdigit() and 0 < int(port) <= 65535:
                options['port'] = int(port)
                continue
            print('invalid port %s' % port)
            sys.exit(1)

        # e.g. the value that followed --port: already consumed above
        if arg == str(options['port']):
            continue

        # any remaining argument is the message file (sys.argv[1] used to be
        # opened unconditionally, which broke "--port 1790 file.msg")
        try:
            with open(arg) as content:
                options['messages'] = [_.strip() for _ in content.readlines() if _.strip() and '#' not in _]
        except IOError:
            flushed('could not open file', arg)
            sys.exit(1)

    try:
        BGPServer('127.0.0.1', options)
        try:
            BGPServer('::1', options)
        except socket.error:
            # IPv6 loopback is not always available (does not work on travis-ci)
            pass
        drop()
        asyncore.loop()
    except socket.error as exc:
        if exc.errno == errno.EACCES:
            flushed('failure: could not bind to port %s - most likely not run as root' % options['port'])
        elif exc.errno == errno.EADDRINUSE:
            flushed('failure: could not bind to port %s - port already in use' % options['port'])
        else:
            flushed('failure', str(exc))
        # report the failure to the caller instead of exiting with status 0
        sys.exit(1)


# start the test BGP peer only when executed as a script, not when imported
if __name__ == '__main__':
    main()
