diff options
Diffstat (limited to 'project2/proj2_s4498062/dns/server.py')
-rw-r--r-- | project2/proj2_s4498062/dns/server.py | 75 |
1 files changed, 64 insertions, 11 deletions
diff --git a/project2/proj2_s4498062/dns/server.py b/project2/proj2_s4498062/dns/server.py index f01043d..f830651 100644 --- a/project2/proj2_s4498062/dns/server.py +++ b/project2/proj2_s4498062/dns/server.py @@ -3,27 +3,70 @@ This module provides a recursive DNS server. You will have to implement this server using the algorithm described in section 4.3.2 of RFC 1034. """ - import re +import socket from threading import Thread import dns.regexes as rgx from dns.classes import Class from dns.types import Type +from dns.message import Header, Message +from dns.resolver import Resolver +from dns.resource import ResourceRecord, ARecordData, CNAMERecordData class RequestHandler(Thread): """ A handler for requests to the DNS server """ - def __init__(self): + def __init__(self, skt, ttl, data, addr, zone): + # pylint: disable=too-many-arguments """ Initialize the handler thread """ super(RequestHandler, self).__init__() self.daemon = True + self.skt = skt + self.ttl = ttl + self.data = data + self.addr = addr + self.zone = zone def run(self): """ Run the handler thread """ - # TODO: Handle DNS request - pass + resolver = Resolver(True, self.ttl) + + request = Message.from_bytes(self.data) + answs, adds, auths = [], [], [] + + for req in request.questions: + rrs = [ + r for r in self.zone + if r.match(type_=req.qtype, class_=req.qclass, name=req.qname)] + if rrs != []: + auths += rrs + elif req.qtype in [Type.A, Type.CNAME] and req.qclass == Class.IN: + name, cnames, addrs = resolver.gethostbyname(req.qname) + if name != req.qname: + answs.append(ResourceRecord( + str(req.qname), Type.CNAME, Class.IN, self.ttl, + CNAMERecordData(str(name)))) + # pylint: disable=bad-continuation + addrs = [ResourceRecord( + name, Type.A, Class.IN, self.ttl, + ARecordData(data)) + for data in addrs] + cnames = [ResourceRecord( + name, Type.CNAME, Class.IN, self.ttl, + CNAMERecordData(data)) + for data in cnames] + if req.qtype == Type.A: + answs += addrs + cnames + if req.qtype == Type.CNAME: + answs += cnames + + header = Header( + request.header.ident, 0, 0, len(answs), len(auths), len(adds)) + response = Message(header, None, answs, auths, adds) + + self.skt.sendto(response.to_bytes(), self.addr) class Server(object): @@ -41,23 +84,32 @@ class Server(object): self.ttl = ttl self.port = port self.done = False - # TODO: create socket + self.zone = [] + + self.skt = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) def serve(self): """ Start serving request """ - # TODO: start listening + self.skt.bind(('localhost', self.port)) while not self.done: - # TODO: receive request and open handler - pass + data, addr = self.skt.recvfrom(512) + reqh = RequestHandler( + self.skt, self.ttl, data, addr, zone=self.zone) + reqh.start() def shutdown(self): """ Shutdown the server """ + self.skt.close() self.done = True - # TODO: shutdown socket def parse_zone_file(self, fname): + """Parse a zone file + + Will crash if the zone file has incorrect syntax. + """ with open(fname) as zonef: zone = zonef.read() + ttl, class_ = 3600000, Class.IN for match in re.finditer(rgx.ZONE_LINE_DOMAIN, zone, re.MULTILINE): match = match.groups() name = match[0] @@ -66,5 +118,6 @@ class Server(object): match[2] or match[3] or Class.to_string(class_)) type_ = Type.from_string(match[5]) data = match[6] - print match - print name, ttl, Class.to_string(class_), Type.to_string(type_), data + + record = ResourceRecord(name, type_, class_, ttl, data) + self.zone.append(record) |