#!/usr/bin/python3
#
# Use avahi to find an _apt_proxy._tcp provider and print
# an HTTP proxy URL suitable for apt.

import asyncio
import functools
import socket
import sys
import time

from subprocess import PIPE, Popen
from syslog import syslog, LOG_INFO, LOG_USER

DEFAULT_CONNECT_TIMEOUT_SEC = 2


def DEBUG(msg):
    """Write msg plus a newline to stderr when --debug is in sys.argv."""
    if "--debug" not in sys.argv:
        return
    sys.stderr.write(msg + "\n")


def get_avahi_discover_timeout():
    """Return the avahi discovery timeout configured in apt, as an int.

    Runs ``/usr/bin/apt-config shell TIMEOUT APT::Avahi-Discover::Timeout``
    and parses its ``TIMEOUT='<n>'`` output. When apt-config prints
    nothing, DEFAULT_CONNECT_TIMEOUT_SEC is returned instead.

    Raises:
        ValueError: if the output does not start with ``TIMEOUT=`` or the
            value is not an integer.
    """
    config_key = "APT::Avahi-Discover::Timeout"
    command = ["/usr/bin/apt-config", "shell", "TIMEOUT", config_key]
    proc = Popen(command, stdout=PIPE, universal_newlines=True)
    output = proc.communicate()[0]

    # Empty output means the option is not set
    if not output:
        fallback_msg = "no timeout set, using default '{}'".format(
            DEFAULT_CONNECT_TIMEOUT_SEC)
        DEBUG(fallback_msg)
        return DEFAULT_CONNECT_TIMEOUT_SEC

    prefix = "TIMEOUT="
    if not output.startswith(prefix):
        raise ValueError("got unexpected apt-config output: '{}'".format(
            output))

    # Everything after the first "=" is the value, wrapped in single quotes
    quoted_value = output.strip()[len(prefix):]
    timeout = int(quoted_value.strip("'"))
    DEBUG("using timeout: '{}'".format(timeout))
    return timeout


@functools.total_ordering
class AptAvahiClient:
    """A candidate proxy address, ordered by how quickly it accepts a connection.

    Instances compare (==, <, and the rest via functools.total_ordering) on
    ``time_to_connect`` only, so sorting a list of clients puts the fastest
    responder first. Because ``__eq__`` is defined without ``__hash__``,
    instances are not hashable.

    Attributes:
        address: the ``(host, port)`` pair passed to the constructor.
        time_to_connect: seconds taken by the last successful
            time_connect(), or ``sys.maxsize`` if no connection succeeded.
        ignore_log_types: log types that log_info() drops; empty by default.
    """

    # Default for log_info(): nothing is suppressed. Without this default
    # log_info() would fail with AttributeError on every call. Kept
    # immutable because it is shared by all instances; assign a new
    # collection on an instance to override it.
    ignore_log_types = ()

    def __init__(self, addr):
        """Remember addr (a ``(host, port)`` pair) as not yet connected."""
        # sys.maxsize is the "never connected" sentinel; it sorts after any
        # measured connect time.
        self.time_to_connect = sys.maxsize
        self.address = addr

    async def time_connect(self):
        """Open a TCP connection to the address and record how long it took.

        On any failure ``time_to_connect`` keeps its previous value.
        """
        try:
            # A monotonic clock cannot jump when the wall clock is adjusted,
            # so the measured duration is never negative or inflated.
            started = time.monotonic()
            _reader, writer = await asyncio.open_connection(
                self.address[0], self.address[1])
            self.time_to_connect = time.monotonic() - started
            writer.close()
        except Exception:
            # Deliberate best effort: a host that cannot be reached (or a
            # malformed address) simply keeps the sentinel timing.
            pass

    def __eq__(self, other):
        # Let Python handle comparisons against unrelated types instead of
        # failing on a missing attribute.
        if not isinstance(other, AptAvahiClient):
            return NotImplemented
        return self.time_to_connect == other.time_to_connect

    def __lt__(self, other):
        if not isinstance(other, AptAvahiClient):
            return NotImplemented
        return self.time_to_connect < other.time_to_connect

    def __repr__(self):
        return "<{}> {}: {}".format(
            self.__class__.__name__, self.address, self.time_to_connect)

    def log(self, message):
        """Send message to syslog with LOG_INFO priority, LOG_USER facility."""
        syslog((LOG_INFO | LOG_USER), '%s\n' % str(message))

    def log_info(self, message, type='info'):
        """Log ``"<type>: <message>"`` unless type is in ignore_log_types."""
        if type not in self.ignore_log_types:
            self.log('%s: %s' % (type, message))


def is_ipv6(a):
    """Return True when a contains a colon, which is taken to mean IPv6."""
    has_colon = ':' in a
    return has_colon


def is_linklocal(addr):
    """Return True if addr is an IPv6 link-local address.

    The address is parsed rather than compared as text, so upper-case and
    uncompressed spellings (``FE80::1``, ``fe80:0:0:0:a:b:c:d``) are
    recognised as well as the usual ``fe80::...`` form. IPv4 addresses
    always yield False. Text that is not a valid address literal falls
    back to a case-insensitive ``fe80::`` prefix check.
    """
    import ipaddress

    # Drop an optional zone index ("fe80::1%eth0"); not every Python
    # version's ipaddress module accepts scoped literals.
    host = addr.partition("%")[0]
    try:
        parsed = ipaddress.ip_address(host)
    except ValueError:
        return addr.lower().startswith("fe80::")
    return parsed.version == 6 and parsed.is_link_local


def get_proxy_host_port_from_avahi():
    """Return the ``(address, port)`` of the fastest apt proxy, or None.

    Steps:
      1. Run ``avahi-browse`` for ``_apt_proxy._tcp`` and collect the
         address/port of every line starting with ``=``, skipping IPv6
         link-local addresses.
      2. Keep only addresses that socket.getaddrinfo() resolves with
         AI_ADDRCONFIG, IPv6 candidates first.
      3. Connect to all remaining candidates concurrently, bounded by the
         timeout from get_avahi_discover_timeout(), and pick the one that
         connected fastest.

    Returns None when no candidate is found or none accepted a connection
    in time.
    """
    service = '_apt_proxy._tcp'

    # Obtain all of the services addresses from avahi, pulling the IPv6
    # addresses to the top.
    addr4 = []
    addr6 = []
    # The context manager closes the pipe and waits for avahi-browse, so
    # no unreaped child or open file descriptor is left behind.
    with Popen(['avahi-browse', '-kprtf', service], stdout=PIPE,
               universal_newlines=True) as p:
        DEBUG("avahi-browse output:")
        for line in p.stdout:
            DEBUG(" '{}'".format(line))
            # Only '=' lines are parsed (presumably the resolved-service
            # records of avahi-browse's parsable output).
            if not line.startswith('='):
                continue
            tokens = line.split(';')
            try:
                addr = tokens[7]
                port = int(tokens[8])
            except (IndexError, ValueError):
                # A truncated or garbled record must not abort discovery
                DEBUG("skipping malformed avahi-browse line")
                continue
            if is_ipv6(addr):
                # We need to skip ipv6 link-local addresses since
                # APT can't use them
                if not is_linklocal(addr):
                    addr6.append((addr, port))
            else:
                addr4.append((addr, port))

    # Run through the offered addresses and see if we have a bound local
    # address for it.
    addrs = []
    for (ip, port) in addr6 + addr4:
        try:
            res = socket.getaddrinfo(ip, port, 0, 0, 0, socket.AI_ADDRCONFIG)
        except socket.gaierror:
            continue
        if res:
            addrs.append((ip, port))
    if not addrs:
        return None

    # Time a connection to every candidate so they can be sorted by
    # answering speed.
    hosts = [AptAvahiClient(addr) for addr in addrs]
    # Upper bound for the whole probing round (configurable through apt)
    timeout = get_avahi_discover_timeout()

    async def gather_hosts():
        await asyncio.gather(*[h.time_connect() for h in hosts])

    try:
        asyncio.run(asyncio.wait_for(gather_hosts(), timeout=timeout))
    except asyncio.TimeoutError:
        # Hosts that answered before the deadline keep their timings
        pass

    ranked = sorted(hosts)
    DEBUG("sorted hosts: '{}'".format(ranked))

    # No host wanted to connect: even the best entry still carries the
    # "never connected" sentinel.
    fastest_host = ranked[0]
    if fastest_host.time_to_connect == sys.maxsize:
        return None

    return fastest_host.address


if __name__ == "__main__":
    # Print the discovered proxy as an http:// URL; print nothing when no
    # proxy was found.
    address = get_proxy_host_port_from_avahi()
    if address:
        ip, port = address
        # IPv6 literals are wrapped in brackets inside the URL
        host = "[{}]".format(ip) if is_ipv6(ip) else ip
        print("http://{}:{}/".format(host, port))
