1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
|
""" A recursive DNS server
This module provides a recursive DNS server. You will have to implement this
server using the algorithm described in section 4.3.2 of RFC 1034.
"""
import re
import socket
from threading import Thread
import dns.regexes as rgx
from dns.classes import Class
from dns.types import Type
from dns.message import Header, Message
from dns.resolver import Resolver
from dns.resource import \
ResourceRecord, ARecordData, NSRecordData, CNAMERecordData
class RequestHandler(Thread):
    """ Thread that answers one datagram received by the DNS server

    Records for the reply come from the zone list handed to the handler or,
    for A/CNAME questions of class IN that no zone record matches, from a
    lookup through a ``Resolver``.
    """

    def __init__(self, skt, ttl, data, addr, zone):
        # pylint: disable=too-many-arguments
        """ Remember everything needed to answer a single request

        Args:
            skt: socket the reply is sent on (its ``sendto`` is used)
            ttl: ttl given to the resolver and to the generated records
            data: raw bytes of the request message
            addr: address the reply is sent to
            zone: iterable of records offering ``match(type_, class_, name)``
        """
        super(RequestHandler, self).__init__()
        self.daemon = True
        self.skt, self.addr = skt, addr
        self.ttl, self.data = ttl, data
        self.zone = zone

    def run(self):
        """ Parse the request, gather the records and send the reply """
        resolver = Resolver(True, self.ttl)
        request = Message.from_bytes(self.data)
        answers = []
        additionals = []  # never filled; the reply carries it empty
        authorities = []

        for question in request.questions:
            zone_hits = [
                record for record in self.zone
                if record.match(
                    type_=question.qtype, class_=question.qclass,
                    name=question.qname)]
            if zone_hits:
                # Matching zone records are reported in the authority list.
                authorities.extend(zone_hits)
                continue

            resolvable = (
                question.qtype in [Type.A, Type.CNAME]
                and question.qclass == Class.IN)
            if not resolvable:
                # Anything else contributes no records to the reply.
                continue

            name, aliases, addresses = resolver.gethostbyname(question.qname)
            if name != question.qname:
                # The resolver returned a different name than the one asked
                # for: add a CNAME record from the queried name to it.
                answers.append(ResourceRecord(
                    str(question.qname), Type.CNAME, Class.IN, self.ttl,
                    CNAMERecordData(str(name))))

            address_rrs = []
            for address in addresses:
                address_rrs.append(ResourceRecord(
                    name, Type.A, Class.IN, self.ttl,
                    ARecordData(address)))
            alias_rrs = []
            for alias in aliases:
                alias_rrs.append(ResourceRecord(
                    name, Type.CNAME, Class.IN, self.ttl,
                    CNAMERecordData(alias)))

            if question.qtype == Type.A:
                answers.extend(address_rrs)
                answers.extend(alias_rrs)
            if question.qtype == Type.CNAME:
                answers.extend(alias_rrs)

        # The reply reuses the request's ident; the second and third Header
        # arguments are passed as 0 and the question section as None.
        header = Header(
            request.header.ident, 0, 0,
            len(answers), len(authorities), len(additionals))
        reply = Message(header, None, answers, authorities, additionals)
        self.skt.sendto(reply.to_bytes(), self.addr)
class Server(object):
    """ A recursive DNS server

    Listens on a UDP socket and starts one ``RequestHandler`` thread per
    received datagram. ``serve`` blocks until ``shutdown`` is called
    (typically from another thread).
    """

    # Seconds ``serve`` waits for a datagram before re-checking ``done``.
    POLL_INTERVAL = 0.5

    def __init__(self, port, caching, ttl):
        """ Initialize the server

        Args:
            port (int): port that server is listening on
            caching (bool): server uses resolver with caching if true
            ttl (int): ttl for records (if > 0) of cache
        """
        # NOTE(review): ``caching`` is stored but not read anywhere in this
        # class; confirm whether the request handlers should receive it.
        self.caching = caching
        self.ttl = ttl
        self.port = port
        self.done = False
        self.zone = []
        self.skt = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

    def serve(self):
        """ Start serving requests

        Binds the socket to ``localhost`` and hands every received datagram
        (at most 512 bytes are read) to a new ``RequestHandler`` thread.
        Returns once ``shutdown`` has been called; returns immediately when
        the server was already shut down.

        Raises:
            socket.error: if binding fails, or if receiving fails while the
                server has not been shut down.
        """
        if self.done:
            # Already shut down: the socket is closed, binding would fail.
            return
        self.skt.bind(('localhost', self.port))
        # A blocking recvfrom would never get to re-check ``done``, so poll.
        # The timeout also applies to handlers sending on this same socket.
        self.skt.settimeout(self.POLL_INTERVAL)
        while not self.done:
            try:
                data, addr = self.skt.recvfrom(512)
            except socket.timeout:
                continue
            except socket.error:
                if self.done:
                    # The socket was closed by shutdown(); a normal exit.
                    break
                raise
            reqh = RequestHandler(
                self.skt, self.ttl, data, addr, zone=self.zone)
            reqh.start()

    def shutdown(self):
        """ Shutdown the server

        Marks the server as done and closes the socket. Safe to call more
        than once.
        """
        # Set the flag first so that ``serve`` can tell the socket error
        # caused by the close below apart from a genuine failure.
        self.done = True
        self.skt.close()

    def parse_zone_file(self, fname):
        """Parse a zone file

        Every line matched by ``rgx.ZONE_LINE_DOMAIN`` whose type is A, NS
        or CNAME is appended to ``self.zone`` as a ``ResourceRecord``; lines
        of any other type are skipped.

        Will crash if the zone file has incorrect syntax.

        Args:
            fname (str): path of the zone file to read
        """
        with open(fname) as zonef:
            zone = zonef.read()
        # TTL and class carry over from one record to the next when a line
        # does not specify them.
        ttl, class_ = 3600000, Class.IN
        for line in re.finditer(rgx.ZONE_LINE_DOMAIN, zone, re.MULTILINE):
            fields = line.groups()
            # The last character of the name is dropped (presumably the
            # trailing dot of an absolute domain name; verify against the
            # regex).
            name = fields[0][:-1]
            # An explicit TTL field is divided by 1000; the carried-over
            # value is pre-multiplied so that it passes through unchanged.
            # Floor division keeps the TTL an int on Python 3 as well.
            ttl = int(fields[1] or fields[4] or ttl * 1000) // 1000
            class_ = Class.from_string(
                fields[2] or fields[3] or Class.to_string(class_))
            type_ = Type.from_string(fields[5])
            data = fields[6]
            if type_ == Type.A:
                cls = ARecordData
            elif type_ == Type.NS:
                cls = NSRecordData
            elif type_ == Type.CNAME:
                cls = CNAMERecordData
            else:
                continue
            record = ResourceRecord(name, type_, class_, ttl, cls(data))
            self.zone.append(record)
|