summaryrefslogtreecommitdiff
path: root/project2/proj2_s4498062/dns/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'project2/proj2_s4498062/dns/server.py')
-rw-r--r--project2/proj2_s4498062/dns/server.py75
1 files changed, 64 insertions, 11 deletions
diff --git a/project2/proj2_s4498062/dns/server.py b/project2/proj2_s4498062/dns/server.py
index f01043d..f830651 100644
--- a/project2/proj2_s4498062/dns/server.py
+++ b/project2/proj2_s4498062/dns/server.py
@@ -3,27 +3,70 @@
This module provides a recursive DNS server. You will have to implement this
server using the algorithm described in section 4.3.2 of RFC 1034.
"""
-
import re
+import socket
from threading import Thread
import dns.regexes as rgx
from dns.classes import Class
from dns.types import Type
+from dns.message import Header, Message
+from dns.resolver import Resolver
+from dns.resource import ResourceRecord, ARecordData, CNAMERecordData
class RequestHandler(Thread):
""" A handler for requests to the DNS server """
- def __init__(self):
+ def __init__(self, skt, ttl, data, addr, zone):
+ # pylint: disable=too-many-arguments
""" Initialize the handler thread """
super(RequestHandler, self).__init__()
self.daemon = True
+ self.skt = skt
+ self.ttl = ttl
+ self.data = data
+ self.addr = addr
+ self.zone = zone
def run(self):
""" Run the handler thread """
- # TODO: Handle DNS request
- pass
+ resolver = Resolver(True, self.ttl)
+
+ request = Message.from_bytes(self.data)
+ answs, adds, auths = [], [], []
+
+ for req in request.questions:
+ rrs = [
+ r for r in self.zone
+ if r.match(type_=req.qtype, class_=req.qclass, name=req.qname)]
+ if rrs != []:
+ auths += rrs
+ elif req.qtype in [Type.A, Type.CNAME] and req.qclass == Class.IN:
+ name, cnames, addrs = resolver.gethostbyname(req.qname)
+ if name != req.qname:
+ answs.append(ResourceRecord(
+ str(req.qname), Type.CNAME, Class.IN, self.ttl,
+ CNAMERecordData(str(name))))
+ # pylint: disable=bad-continuation
+ addrs = [ResourceRecord(
+ name, Type.A, Class.IN, self.ttl,
+ ARecordData(data))
+ for data in addrs]
+ cnames = [ResourceRecord(
+ name, Type.CNAME, Class.IN, self.ttl,
+ CNAMERecordData(data))
+ for data in cnames]
+ if req.qtype == Type.A:
+ answs += addrs + cnames
+ if req.qtype == Type.CNAME:
+ answs += cnames
+
+ header = Header(
+ request.header.ident, 0, 0, len(answs), len(auths), len(adds))
+ response = Message(header, None, answs, auths, adds)
+
+ self.skt.sendto(response.to_bytes(), self.addr)
class Server(object):
@@ -41,23 +84,32 @@ class Server(object):
self.ttl = ttl
self.port = port
self.done = False
- # TODO: create socket
+ self.zone = []
+
+ self.skt = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
def serve(self):
""" Start serving request """
- # TODO: start listening
+ self.skt.bind(('localhost', self.port))
while not self.done:
- # TODO: receive request and open handler
- pass
+ data, addr = self.skt.recvfrom(512)
+ reqh = RequestHandler(
+ self.skt, self.ttl, data, addr, zone=self.zone)
+ reqh.start()
def shutdown(self):
""" Shutdown the server """
+ self.skt.close()
self.done = True
- # TODO: shutdown socket
def parse_zone_file(self, fname):
+ """Parse a zone file
+
+ Will crash if the zone file has incorrect syntax.
+ """
with open(fname) as zonef:
zone = zonef.read()
+ ttl, class_ = 3600000, Class.IN
for match in re.finditer(rgx.ZONE_LINE_DOMAIN, zone, re.MULTILINE):
match = match.groups()
name = match[0]
@@ -66,5 +118,6 @@ class Server(object):
match[2] or match[3] or Class.to_string(class_))
type_ = Type.from_string(match[5])
data = match[6]
- print match
- print name, ttl, Class.to_string(class_), Type.to_string(type_), data
+
+ record = ResourceRecord(name, type_, class_, ttl, data)
+ self.zone.append(record)