summaryrefslogtreecommitdiff
path: root/project2/proj2_s4498062/dns/resolver.py
blob: 8bd7bb956ae5998092fd59c04dd520fd09bd66c5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#!/usr/bin/env python2

""" DNS Resolver

This module contains a class for resolving hostnames. You will have to implement
things in this module. This resolver will be both used by the DNS client and the
DNS server, but with a different list of servers.
"""

from random import randint
import socket

from dns.classes import Class
from dns.types import Type

from dns.cache import RecordCache
from dns.message import Message, Question, Header


class Resolver(object):
    """ DNS resolver """

    # IPv4 addresses of the root name servers used as initial hints.
    # NOTE(review): verify against the current IANA root hints file; root
    # server addresses occasionally change.
    ROOT_SERVERS = [
        '198.41.0.4',
        '192.228.79.201',
        '192.33.4.12',
        '199.7.91.13',
        '192.203.230.10',
        '192.5.5.241',
        '192.112.36.4',
        '198.97.190.53',
        '192.36.148.17',
        '192.58.128.30',
        '193.0.14.129',
        '199.7.83.42',
        '202.12.27.33'
        ]

    # Maximum nesting of recursive lookups (following CNAMEs and resolving
    # name-server names that came without glue). Prevents CNAME loops from
    # recursing until the interpreter's recursion limit is reached.
    MAX_DEPTH = 10

    def __init__(self, caching, ttl, timeout=3):
        """ Initialize the resolver

        Args:
            caching (bool): caching is enabled if True
            ttl (int): ttl of cache entries (if > 0)
            timeout (float): socket timeout in seconds for each query
        """
        self.caching = caching
        self.ttl = ttl
        self.timeout = timeout
        # Only created when caching is enabled; every use is guarded by
        # self.caching.
        self.cache = RecordCache(ttl) if caching else None

    def do_query(self, hint, hostname, type_, class_=Class.IN, caching=True):
        """Do a query to a hint

        Args:
            hint (str): IPv4 address of the name server to ask
            hostname (str): the name to query
            type_ (Type): record type to ask for
            class_ (Class): record class to ask for
            caching (bool): use the cache for this query (only effective
                when the resolver itself has caching enabled)

        Returns:
            list: cached records for the question if there are any,
                otherwise the answer, authority and additional records of
                the reply; [] when the server could not be reached, did not
                reply within the timeout, or replied with a different ID.
        """
        use_cache = self.caching and caching
        if use_cache:
            records = self.cache.lookup(hostname, type_, class_)
            if records != []:
                return records

        ident = randint(0, 65535)
        header = Header(ident, 0, 1, 0, 0, 0)
        header.qr = 0
        header.opcode = 0
        header.rd = 1
        req = Message(header, [Question(hostname, type_, class_)])

        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            sock.settimeout(self.timeout)
            sock.sendto(req.to_bytes(), (hint, 53))
            data = sock.recv(512)
        except socket.error:
            # socket.timeout is a subclass of socket.error: a silent or
            # unreachable server is treated as "no answer".
            return []
        finally:
            # Release the socket on every path, including errors.
            sock.close()

        resp = Message.from_bytes(data)
        if resp.header.ident != ident:
            return []
        if use_cache:
            self.cache.add_records_from(resp)
        return resp.answers + resp.authorities + resp.additionals

    def do_query_to_multiple(
            self, hints, hostname, type_, class_=Class.IN, caching=True):
        """Do a query to multiple hints, return the remaining hints

        Hints are popped from the end of ``hints`` (the list is modified in
        place) until one of them yields a non-empty response.

        Returns:
            (list, list, list): (hints that were queried, hints not yet
                queried, records of the first non-empty response or [])
        """
        seen = []
        while hints:
            hint = hints.pop()
            seen.append(hint)
            response = self.do_query(hint, hostname, type_, class_, caching)
            # do_query signals failure with [] (never None), so move on to
            # the next hint when nothing came back.
            if response != []:
                return seen, hints, response
        return seen, [], []

    def gethostbyname(self, hostname):
        """ Translate a host name to IPv4 address.

        Resolution starts at the root servers and follows referrals and
        CNAMEs. Nested lookups are bounded by MAX_DEPTH.

        Args:
            hostname (str): the hostname to resolve

        Returns:
            (str, [str], [str]): (hostname, aliaslist, ipaddrlist); the
                first element is the queried hostname, aliaslist holds the
                CNAME targets that were followed, and ipaddrlist is empty
                when resolution failed.
        """
        return self._resolve(hostname, 0)

    def _resolve(self, hostname, depth):
        """Resolve hostname; depth is the number of enclosing lookups."""
        if depth > self.MAX_DEPTH:
            return hostname, [], []

        if self.caching:
            cached = self._resolve_from_cache(hostname, depth)
            if cached is not None:
                return cached

        if hostname == '':
            return hostname, [], []

        seen = []
        hints = self.ROOT_SERVERS[:]
        aliases = []
        while hints:
            used, hints, info = self.do_query_to_multiple(
                hints, hostname, Type.A)
            seen += used

            # CNAMEs for hostname introduced by this response only, so an
            # alias is neither listed twice nor re-resolved on later passes.
            new_aliases = []
            for record in info:
                if not record.match(
                        type_=Type.CNAME, class_=Class.IN, name=hostname):
                    continue
                alias = record.rdata.data
                if alias not in aliases and alias not in new_aliases:
                    new_aliases.append(alias)
            aliases += new_aliases

            # Case 1: answer
            ips = [
                r.rdata.data for r in info
                if r.match(type_=Type.A, class_=Class.IN, name=hostname)]
            if ips:
                return hostname, aliases, ips

            # Case 2: hostname is an alias. Follow the CNAME before any
            # referral carried in the same response's authority section;
            # otherwise the alias would never be resolved.
            for alias in new_aliases:
                _, extra_aliases, alias_addresses = self._resolve(
                    alias, depth + 1)
                if alias_addresses:
                    return hostname, aliases + extra_aliases, alias_addresses

            # Case 3: referral to name servers
            auths = [
                r.rdata.data for r in info
                if r.match(type_=Type.NS, class_=Class.IN)]
            ips = [
                add.rdata.data for ns in auths for add in info
                if add.match(name=ns, type_=Type.A)]
            if ips:
                hints += [ip for ip in ips if ip not in seen]
                continue
            if auths:
                # No glue records: resolve the name servers' names first.
                auth_ips = [
                    ip for ns in auths
                    for ip in self._resolve(ns, depth + 1)[2]]
                hints += [ip for ip in auth_ips if ip not in seen]
                continue

        return hostname, aliases, []

    def _resolve_from_cache(self, hostname, depth):
        """Answer from cached A/CNAME records; None if the cache can't.

        The result has the same shape as gethostbyname's: the queried
        hostname first, then the followed aliases, then the addresses.
        """
        addrs = self.cache.lookup(hostname, Type.A, Class.IN)
        cnames = self.cache.lookup(hostname, Type.CNAME, Class.IN)
        if addrs != []:
            return hostname, \
                [r.rdata.data for r in cnames], \
                [r.rdata.data for r in addrs]
        for cname in cnames:
            target = cname.rdata.data
            _, extra_aliases, target_addrs = self._resolve(target, depth + 1)
            if target_addrs != []:
                aliases = [target] + [
                    a for a in extra_aliases if a != target]
                return hostname, aliases, list(set(target_addrs))
        return None