summaryrefslogtreecommitdiff
path: root/project2/proj2_s4498062/dns/server.py
blob: f830651f3fd16db0858691f8bdd9f742f48d4999 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
""" A recursive DNS server

This module provides a recursive DNS server. You will have to implement this
server using the algorithm described in section 4.3.2 of RFC 1034.
"""
import re
import socket
from threading import Thread

import dns.regexes as rgx
from dns.classes import Class
from dns.types import Type
from dns.message import Header, Message
from dns.resolver import Resolver
from dns.resource import ResourceRecord, ARecordData, CNAMERecordData


class RequestHandler(Thread):
    """ Worker thread that answers a single request datagram """

    def __init__(self, skt, ttl, data, addr, zone):
        # pylint: disable=too-many-arguments
        """ Remember everything needed to answer one request

        Args:
            skt: socket the reply is sent on (through ``sendto``)
            ttl: ttl given to the resolver and put on the records built here
            data: raw bytes of the received request
            addr: address the reply is sent to
            zone: iterable of records that are matched against each question
        """
        super(RequestHandler, self).__init__()
        self.skt, self.addr = skt, addr
        self.ttl, self.data = ttl, data
        self.zone = zone
        # Daemon thread: a pending request never keeps the process alive.
        self.daemon = True

    def run(self):
        """ Parse the request, gather records and send back the reply

        For every question the zone records are tried first; matches go in
        the authority list. Otherwise, A and CNAME questions of class IN are
        handed to the resolver and the result goes in the answer list. Any
        other question contributes nothing to the reply.
        """
        resolver = Resolver(True, self.ttl)
        request = Message.from_bytes(self.data)

        answers, authorities, additionals = [], [], []

        for question in request.questions:
            # Records from the zone take precedence over resolving.
            zone_hits = []
            for record in self.zone:
                if record.match(
                        type_=question.qtype, class_=question.qclass,
                        name=question.qname):
                    zone_hits.append(record)
            if zone_hits:
                authorities.extend(zone_hits)
                continue

            resolvable = (
                question.qtype in [Type.A, Type.CNAME]
                and question.qclass == Class.IN)
            if not resolvable:
                continue

            name, cnames, addrs = resolver.gethostbyname(question.qname)

            # The resolver came back with a different name than was asked
            # for: tell the client by means of a CNAME record.
            if name != question.qname:
                alias = ResourceRecord(
                    str(question.qname), Type.CNAME, Class.IN, self.ttl,
                    CNAMERecordData(str(name)))
                answers.append(alias)

            # pylint: disable=bad-continuation
            addr_records = [
                ResourceRecord(
                    name, Type.A, Class.IN, self.ttl, ARecordData(address))
                for address in addrs]
            cname_records = [
                ResourceRecord(
                    name, Type.CNAME, Class.IN, self.ttl,
                    CNAMERecordData(cname))
                for cname in cnames]

            if question.qtype == Type.A:
                answers.extend(addr_records + cname_records)
            if question.qtype == Type.CNAME:
                answers.extend(cname_records)

        # The reply reuses the ident of the request so that the client can
        # pair the two.
        header = Header(
            request.header.ident, 0, 0,
            len(answers), len(authorities), len(additionals))
        response = Message(header, None, answers, authorities, additionals)

        self.skt.sendto(response.to_bytes(), self.addr)


class Server(object):
    """ A recursive DNS server

    Receives UDP datagrams on ``localhost`` and hands every datagram to a
    RequestHandler thread of its own.
    """

    def __init__(self, port, caching, ttl):
        """ Initialize the server

        Args:
            port (int): port that server is listening on
            caching (bool): server uses resolver with caching if true
            ttl (int): ttl for records (if > 0) of cache
        """
        # NOTE(review): ``caching`` is stored here but ``serve`` never hands
        # it to RequestHandler -- confirm whether that is intended.
        self.caching = caching
        self.ttl = ttl
        self.port = port
        self.done = False
        # Records added by parse_zone_file; shared with every handler.
        self.zone = []

        self.skt = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

    def serve(self):
        """ Start serving request

        Binds the socket and loops until ``done`` is set, starting one
        RequestHandler per received datagram. A socket error that shows up
        after ``shutdown`` was called ends the loop quietly.

        Raises:
            socket.error: if binding fails, or if receiving fails while the
                server has not been shut down.
        """
        self.skt.bind(('localhost', self.port))
        while not self.done:
            try:
                data, addr = self.skt.recvfrom(512)
            except socket.error:
                # shutdown() closes the socket we are receiving on; that is
                # a regular exit and not a failure.
                if self.done:
                    break
                raise
            reqh = RequestHandler(
                self.skt, self.ttl, data, addr, zone=self.zone)
            reqh.start()

    def shutdown(self):
        """ Shutdown the server """
        # Raise the flag before closing, so that serve() can tell the error
        # caused by the closed socket apart from a genuine failure.
        self.done = True
        self.skt.close()

    def parse_zone_file(self, fname):
        """Parse a zone file

        Every match of ``rgx.ZONE_LINE_DOMAIN`` in the file becomes one
        record that is appended to ``self.zone``. A ttl or class that is
        missing on a line is taken over from the previous matched line; the
        first line falls back on 3600000 and ``Class.IN``.

        Will crash if the zone file has incorrect syntax.

        Args:
            fname (str): path of the zone file
        """
        with open(fname) as zonef:
            zone = zonef.read()

        ttl, class_ = 3600000, Class.IN
        for match in re.finditer(rgx.ZONE_LINE_DOMAIN, zone, re.MULTILINE):
            groups = match.groups()
            name = groups[0]
            # ttl and class can each sit in one of two groups; presumably
            # because the zone syntax allows them in either order -- verify
            # against rgx.ZONE_LINE_DOMAIN.
            ttl = int(groups[1] or groups[4] or ttl)
            class_ = Class.from_string(
                groups[2] or groups[3] or Class.to_string(class_))
            type_ = Type.from_string(groups[5])
            data = groups[6]

            self.zone.append(ResourceRecord(name, type_, class_, ttl, data))