summaryrefslogtreecommitdiff
path: root/project2/proj2_s4498062/dns/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'project2/proj2_s4498062/dns/server.py')
-rw-r--r--project2/proj2_s4498062/dns/server.py133
1 files changed, 133 insertions, 0 deletions
diff --git a/project2/proj2_s4498062/dns/server.py b/project2/proj2_s4498062/dns/server.py
new file mode 100644
index 0000000..10cad8b
--- /dev/null
+++ b/project2/proj2_s4498062/dns/server.py
@@ -0,0 +1,133 @@
+""" A recursive DNS server
+
+This module provides a recursive DNS server. You will have to implement this
+server using the algorithm described in section 4.3.2 of RFC 1034.
+"""
+import re
+import socket
+from threading import Thread
+
+import dns.regexes as rgx
+from dns.classes import Class
+from dns.types import Type
+from dns.message import Header, Message
+from dns.resolver import Resolver
+from dns.resource import \
+ ResourceRecord, ARecordData, NSRecordData, CNAMERecordData
+
+
+class RequestHandler(Thread):
+ """ A handler for requests to the DNS server """
+
+ def __init__(self, skt, ttl, data, addr, zone):
+ # pylint: disable=too-many-arguments
+ """ Initialize the handler thread """
+ super(RequestHandler, self).__init__()
+ self.daemon = True
+ self.skt = skt
+ self.ttl = ttl
+ self.data = data
+ self.addr = addr
+ self.zone = zone
+
+ def run(self):
+ """ Run the handler thread """
+ resolver = Resolver(True, self.ttl)
+
+ request = Message.from_bytes(self.data)
+ answs, adds, auths = [], [], []
+
+ for req in request.questions:
+ rrs = [
+ r for r in self.zone
+ if r.match(type_=req.qtype, class_=req.qclass, name=req.qname)]
+ if rrs != []:
+ auths += rrs
+ elif req.qtype in [Type.A, Type.CNAME] and req.qclass == Class.IN:
+ name, cnames, addrs = resolver.gethostbyname(req.qname)
+ if name != req.qname:
+ answs.append(ResourceRecord(
+ str(req.qname), Type.CNAME, Class.IN, self.ttl,
+ CNAMERecordData(str(name))))
+ # pylint: disable=bad-continuation
+ addrs = [ResourceRecord(
+ name, Type.A, Class.IN, self.ttl,
+ ARecordData(data))
+ for data in addrs]
+ cnames = [ResourceRecord(
+ name, Type.CNAME, Class.IN, self.ttl,
+ CNAMERecordData(data))
+ for data in cnames]
+ if req.qtype == Type.A:
+ answs += addrs + cnames
+ if req.qtype == Type.CNAME:
+ answs += cnames
+
+ header = Header(
+ request.header.ident, 0, 0, len(answs), len(auths), len(adds))
+ response = Message(header, None, answs, auths, adds)
+
+ self.skt.sendto(response.to_bytes(), self.addr)
+
+
+class Server(object):
+ """ A recursive DNS server """
+
+ def __init__(self, port, caching, ttl):
+ """ Initialize the server
+
+ Args:
+ port (int): port that server is listening on
+ caching (bool): server uses resolver with caching if true
+ ttl (int): ttl for records (if > 0) of cache
+ """
+ self.caching = caching
+ self.ttl = ttl
+ self.port = port
+ self.done = False
+ self.zone = []
+
+ self.skt = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+
+ def serve(self):
+ """ Start serving request """
+ self.skt.bind(('localhost', self.port))
+ while not self.done:
+ data, addr = self.skt.recvfrom(512)
+ reqh = RequestHandler(
+ self.skt, self.ttl, data, addr, zone=self.zone)
+ reqh.start()
+
+ def shutdown(self):
+ """ Shutdown the server """
+ self.skt.close()
+ self.done = True
+
+ def parse_zone_file(self, fname):
+ """Parse a zone file
+
+ Will crash if the zone file has incorrect syntax.
+ """
+ with open(fname) as zonef:
+ zone = zonef.read()
+ ttl, class_ = 3600000, Class.IN
+ for match in re.finditer(rgx.ZONE_LINE_DOMAIN, zone, re.MULTILINE):
+ match = match.groups()
+ name = match[0][:-1]
+ ttl = int(match[1] or match[4] or ttl * 1000) / 1000
+ class_ = Class.from_string(
+ match[2] or match[3] or Class.to_string(class_))
+ type_ = Type.from_string(match[5])
+ data = match[6]
+
+ if type_ == Type.A:
+ cls = ARecordData
+ elif type_ == Type.NS:
+ cls = NSRecordData
+ elif type_ == Type.CNAME:
+ cls = CNAMERecordData
+ else:
+ continue
+
+ record = ResourceRecord(name, type_, class_, ttl, cls(data))
+ self.zone.append(record)