#!/usr/bin/env python2

""" DNS Resolver

This module contains a class for resolving hostnames. You will have to implement
things in this module. This resolver will be both used by the DNS client and the
DNS server, but with a different list of servers.
"""

from random import randint
import socket

from dns.classes import Class
from dns.types import Type

from dns.cache import RecordCache
from dns.message import Message, Question, Header
import dns.rcodes


class Resolver(object):
    """ DNS resolver

    Resolves host names iteratively starting from the root servers and,
    when enabled, stores received records in a RecordCache.
    """

    # IPv4 addresses of the root name servers. This list is shared by all
    # instances and is never modified; lookups work on copies of it.
    # NOTE(review): root server addresses are occasionally renumbered
    # (e.g. 192.228.79.201) -- verify against the current IANA root list.
    ROOT_SERVERS = [
        '198.41.0.4',
        '192.228.79.201',
        '192.33.4.12',
        '199.7.91.13',
        '192.203.230.10',
        '192.5.5.241',
        '192.112.36.4',
        '198.97.190.53',
        '192.36.148.17',
        '192.58.128.30',
        '193.0.14.129',
        '199.7.83.42',
        '202.12.27.33'
        ]

    # Maximum nesting of lookups for name-server names that came without glue.
    MAX_RECURSION_DEPTH = 5
    # Maximum number of CNAMEs followed for a single lookup.
    MAX_CNAME_CHAIN = 10

    def __init__(self, caching, ttl, timeout=3):
        """ Initialize the resolver

        Args:
            caching (bool): caching is enabled if True
            ttl (int): ttl of cache entries (if > 0); stored as self.ttl
            timeout (int/float): socket timeout in seconds for each query
        """
        self.caching = caching
        self.ttl = ttl
        self.timeout = timeout

        if self.caching:
            self.cache = RecordCache()

    def do_query(self, hint, hostname, type_, class_=Class.IN, caching=True):
        """Send one query to a single server and return the records received.

        Args:
            hint (str): IPv4 address of the server to query (UDP port 53).
            hostname (str): domain name to ask about.
            type_ (Type): query type.
            class_ (Class): query class.
            caching (bool): consult and update the cache for this query; only
                effective when the resolver was created with caching enabled.

        Returns:
            list: the cached records for (hostname, type_, class_) if the cache
            has any; otherwise the answer, authority and additional records of
            the response. Returns [] when the server does not answer in time,
            a socket error occurs, or the response id does not match.
        """
        use_cache = self.caching and caching
        if use_cache:
            records = self.cache.lookup(hostname, type_, class_)
            if records:  # treats both [] and None as a cache miss
                return records

        ident = randint(0, 65535)
        header = Header(ident, 0, 1, 0, 0, 0)
        header.qr = 0
        header.opcode = 0
        header.rd = 1
        req = Message(header, [Question(hostname, type_, class_)])

        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            sock.settimeout(self.timeout)
            sock.sendto(req.to_bytes(), (hint, 53))
            data = sock.recv(512)
        except socket.error:
            # socket.timeout is a subclass of socket.error: an unreachable or
            # slow server counts as "no answer" so the caller can try another.
            return []
        finally:
            sock.close()

        resp = Message.from_bytes(data)
        if resp.header.ident != ident:
            return []
        if use_cache:
            self.cache.add_records_from(resp)
        return resp.answers + resp.authorities + resp.additionals

    def do_query_to_multiple(
            self, hints, hostname, type_, class_=Class.IN, caching=True):
        """Query servers from `hints` until one of them returns records.

        Servers are tried from the end of the list backwards. The caller's
        list is not modified.

        Returns:
            (list, list): the servers that were not tried yet and the records
            returned by the first server that yielded any, or ([], []) when
            no server did.
        """
        remaining = list(hints)
        while remaining:
            hint = remaining.pop()
            response = self.do_query(hint, hostname, type_, class_, caching)
            if response:  # do_query returns [] on timeout / mismatch
                return remaining, response
        return [], []

    def gethostbyname(self, hostname):
        """ Translate a host name to IPv4 addresses.

        Implements the iterative algorithm of RFC 1034 section 5.3.3: the
        full name is sent to the best known servers (initially the root
        servers); a referral to a zone closer to the name replaces the server
        list, and a CNAME restarts the search at the canonical name.

        Args:
            hostname (str): the hostname to resolve (a trailing dot is allowed)

        Returns:
            (str, [str], [str]): (hostname, aliaslist, ipaddrlist), where
            hostname is the argument unchanged, aliaslist holds the CNAME
            targets that were followed, and ipaddrlist is empty when the name
            could not be resolved.
        """
        return self._resolve(hostname, 0)

    def _resolve(self, hostname, depth):
        """Iteratively resolve `hostname` (see gethostbyname).

        Args:
            hostname (str): the name to resolve.
            depth (int): nesting level of name-server lookups; resolution
                gives up once it exceeds MAX_RECURSION_DEPTH.

        Returns:
            (str, [str], [str]): (hostname, aliaslist, ipaddrlist)
        """
        name = hostname.rstrip('.')
        aliases = []
        if not name or depth > self.MAX_RECURSION_DEPTH:
            return hostname, aliases, []

        seen = set([name])  # names already visited, for CNAME loop detection
        hints = list(self.ROOT_SERVERS)  # copy: never mutate the class list
        zone_labels = 0  # label count of the zone served by `hints`; root = 0

        while hints:
            hints, info = self.do_query_to_multiple(hints, name, Type.A)

            # Answer for the name currently being resolved.
            ips = self._addresses(info, name)
            if ips:
                return hostname, aliases, ips

            # Alias: follow the CNAME chain as far as this response goes.
            target = self._cname_target(info, name)
            if target is not None:
                while target is not None:
                    canonical = target.rstrip('.')
                    if (canonical in seen
                            or len(aliases) >= self.MAX_CNAME_CHAIN):
                        return hostname, aliases, []  # loop / chain too long
                    seen.add(canonical)
                    aliases.append(target)
                    name = canonical
                    target = self._cname_target(info, name)
                ips = self._addresses(info, name)
                if ips:
                    return hostname, aliases, ips
                # Restart from the root for the canonical name.
                hints = list(self.ROOT_SERVERS)
                zone_labels = 0
                continue

            # Referral: descend only to zones closer to the name than the
            # current one, so upward/lame referrals cannot cause loops.
            labels, ns_names = self._find_delegation(info, name, zone_labels)
            if ns_names:
                new_hints = self._nameserver_addresses(info, ns_names, depth)
                if new_hints:
                    hints = new_hints
                    zone_labels = labels
            # Otherwise (negative answer, lame server, unusable referral) the
            # loop simply tries the next remaining server.

        return hostname, aliases, []

    @staticmethod
    def _addresses(records, name):
        """Return the data of the IN A records owned by `name`."""
        return [r.rdata.data for r in records
                if r.match(type_=Type.A, class_=Class.IN, name=name)]

    @staticmethod
    def _cname_target(records, name):
        """Return the target of the first IN CNAME owned by `name`, or None."""
        for r in records:
            if r.match(type_=Type.CNAME, class_=Class.IN, name=name):
                return r.rdata.data
        return None

    @staticmethod
    def _find_delegation(records, name, zone_labels):
        """Find the deepest delegation in `records` below the current zone.

        Only suffixes of `name` with more than `zone_labels` labels are
        considered.

        Returns:
            (int, [str]): label count of the delegated zone and its NS names,
            or (zone_labels, []) when there is no such delegation.
        """
        labels = name.split('.')
        for start in range(len(labels) - zone_labels):
            zone = '.'.join(labels[start:])
            ns_names = [r.rdata.data for r in records
                        if r.match(type_=Type.NS, class_=Class.IN, name=zone)]
            if ns_names:
                return len(labels) - start, ns_names
        return zone_labels, []

    def _nameserver_addresses(self, records, ns_names, depth):
        """Return IPv4 addresses for `ns_names`.

        Uses glue A records from `records` when present; otherwise resolves
        the name-server names one nesting level deeper, stopping at the first
        name that yields addresses.
        """
        glue = [r.rdata.data for ns in ns_names for r in records
                if r.match(name=ns, type_=Type.A)]
        if glue:
            return glue
        for ns in ns_names:
            addresses = self._resolve(ns, depth + 1)[2]
            if addresses:
                return addresses
        return []