summaryrefslogtreecommitdiff
path: root/project2/proj2_s4498062/dns/resolver.py
blob: af60686e75b4cecefd99bc5f85546b52da8298d4 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
""" DNS Resolver

This module contains a class for resolving hostnames. You will have to
implement things in this module. This resolver is used by both the DNS
client and the DNS server, each with a different list of servers.
"""

import re
import socket

from dns.cache import RecordCache
from dns.classes import Class
from dns.message import Message, Header, Question
from dns.types import Type
import dns.regexes as rgx


class Resolver(object):
    """ DNS resolver

    Resolves hostnames by sending DNS queries over UDP (port 53) to a list
    of nameservers. The servers passed to the constructor are tried first,
    followed by the addresses in ROOT_SERVERS.
    """

    # IPv4 addresses appended to every resolver's nameserver list.
    ROOT_SERVERS = [
        '198.41.0.4',
        '192.228.79.201',
        '192.33.4.12',
        '199.7.91.13',
        '192.203.230.10',
        '192.5.5.241',
        '192.112.36.4',
        '198.97.190.53',
        '192.36.148.17',
        '192.58.128.30',
        '193.0.14.129',
        '199.7.83.42',
        '202.12.27.33'
        ]

    def __init__(self, nameservers, timeout, caching, ttl):
        """ Initialize the resolver

        Args:
            nameservers (iterable of str or None): servers to try before
                the root servers; None is treated as an empty list
            timeout (float): socket timeout in seconds for each query
            caching (bool): caching is enabled if True
            ttl (int): ttl of cache entries (if > 0); only stored here,
                this class does not read it
        """
        # Build a fresh list so neither the caller's sequence nor the
        # shared ROOT_SERVERS class attribute is aliased.
        self.nameservers = list(nameservers or []) + self.ROOT_SERVERS
        self.timeout = timeout
        self.caching = caching
        self.ttl = ttl

        # Always define the attribute so it exists even when caching is off.
        self.cache = RecordCache() if self.caching else None

    def _exchange(self, query, server):
        """Send one query to one server and return the raw reply bytes.

        Args:
            query: message object providing to_bytes()
            server (str): IPv4 address of the nameserver

        Returns:
            The received datagram (at most 512 bytes), or None when the
            server did not answer within self.timeout.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            sock.settimeout(self.timeout)
            sock.sendto(query.to_bytes(), (server, 53))
            return sock.recv(512)
        except socket.timeout:
            return None
        finally:
            # Release the descriptor on every path (reply, timeout, error).
            sock.close()

    def do_query(self, query, using):
        """Send a query to a list of name servers

        This is a generator: servers are contacted lazily, one at a time,
        as the caller asks for the next response.

        Args:
            query: message object providing to_bytes()
            using (iterable of str): candidate servers; entries that do not
                match rgx.IP are skipped

        Yields:
            The parsed Message for every server that replied in time.
            When caching is enabled, all answer, authority and additional
            records of each reply are added to the cache first.
        """
        for hint in using:
            if re.match(rgx.IP, hint) is None:
                continue

            data = self._exchange(query, hint)
            if data is None:
                # Timed out; move on to the next server.
                continue

            response = Message.from_bytes(data)

            if self.caching:
                records = (response.answers +
                           response.authorities +
                           response.additionals)
                for record in records:
                    self.cache.add_record(record)

            yield response

    def try_hint(self, sub, dom, hints):
        """Helper for get_hints

        Yields the hints themselves when there is no remaining sub-domain
        (sub is empty); otherwise continues the walk by asking get_hints
        for sub under dom, using the hints as servers.
        """
        if sub == '':
            for hint in hints:
                yield hint
        else:
            for new_hint in self.get_hints(
                    sub, dom, hints, in_recursion=True):
                yield new_hint

    def get_hints(self, domain, parent='', using=None, in_recursion=False):
        """Get a list of nameservers for a domain

        Generator that sends NS queries (recursion not desired) and passes
        the resulting hints to try_hint, which either yields them or
        recurses for the remaining part of the name.

        Args:
            domain (str): name still to be resolved; it is split by
                rgx.DOMAIN into (sub, dom) -- presumably the remaining
                labels and the label to query next; verify against
                dns.regexes
            parent (str): already-handled suffix appended to dom
            using (list of str): servers to query; never modified
            in_recursion (bool): when False, self.nameservers is appended
                to the servers to query

        Yields:
            Nameserver hints (cached NS record data first, when caching
            is enabled, then hints extracted from live responses).
        """
        # Work on a copy: the original extended the caller's list in place.
        using = [] if using is None else list(using)

        if not in_recursion:
            using += self.nameservers

        # Debug trace. Written as one formatted string so it prints the
        # same text under Python 2 and is valid syntax under Python 3.
        print('Trying %s %s using %s' % (domain, parent, using))

        domains = re.match(rgx.DOMAIN, domain)
        if domains is None:
            return
        sub, dom = domains.groups()
        if parent != '':
            dom += '.' + parent

        if self.caching:
            hints = self.cache.lookup(dom, Type.NS, Class.IN)
            for hint in self.try_hint(
                    sub, dom, [r.rdata.data for r in hints]):
                yield hint

        header = Header(0, 0, 1, 0, 0, 0)
        header.qr = 0
        header.opcode = 0
        header.rd = 0
        query = Message(header, [Question(dom, Type.NS, Class.IN)])

        for response in self.do_query(query, using):
            # Each item from response.get_hints() must unpack as a pair
            # whose second element is a one-element sequence holding the
            # address.
            hints = [ip for _, [ip] in response.get_hints()]
            for hint in self.try_hint(sub, dom, hints):
                yield hint

    def gethostbyname(self, hostname):
        """ Translate a host name to IPv4 address.

        With caching enabled, cached A records are returned directly
        (together with any cached CNAME data). Otherwise an A query with
        recursion desired is sent to the servers produced by get_hints,
        and the first response received is used.

        Args:
            hostname (str): the hostname to resolve

        Returns:
            (str, [str], [str]): (hostname, aliaslist, ipaddrlist); both
            lists are empty when no server responded
        """
        if self.caching:
            addrs = list(self.cache.lookup(hostname, Type.A, Class.IN))
            cnames = list(self.cache.lookup(hostname, Type.CNAME, Class.IN))
            if addrs:
                return (
                    hostname,
                    [r.rdata.data for r in cnames],
                    [r.rdata.data for r in addrs])

        # Create and send query
        question = Question(hostname, Type.A, Class.IN)
        header = Header(9001, 0, 1, 0, 0, 0)
        header.qr = 0
        header.opcode = 0
        header.rd = 1
        query = Message(header, [question])

        # Sockets are opened (and closed) per server inside do_query, so
        # no socket is needed here.
        for response in self.do_query(query, self.get_hints(hostname)):
            aliases = response.get_aliases()
            addresses = response.get_addresses()

            return hostname, aliases, addresses

        return hostname, [], []