diff options
Diffstat (limited to 'project2/proj2_s4498062/dns/server.py')
-rw-r--r-- | project2/proj2_s4498062/dns/server.py | 133 |
1 files changed, 133 insertions, 0 deletions
diff --git a/project2/proj2_s4498062/dns/server.py b/project2/proj2_s4498062/dns/server.py new file mode 100644 index 0000000..10cad8b --- /dev/null +++ b/project2/proj2_s4498062/dns/server.py @@ -0,0 +1,133 @@ +""" A recursive DNS server + +This module provides a recursive DNS server. You will have to implement this +server using the algorithm described in section 4.3.2 of RFC 1034. +""" +import re +import socket +from threading import Thread + +import dns.regexes as rgx +from dns.classes import Class +from dns.types import Type +from dns.message import Header, Message +from dns.resolver import Resolver +from dns.resource import \ + ResourceRecord, ARecordData, NSRecordData, CNAMERecordData + + +class RequestHandler(Thread): + """ A handler for requests to the DNS server """ + + def __init__(self, skt, ttl, data, addr, zone): + # pylint: disable=too-many-arguments + """ Initialize the handler thread """ + super(RequestHandler, self).__init__() + self.daemon = True + self.skt = skt + self.ttl = ttl + self.data = data + self.addr = addr + self.zone = zone + + def run(self): + """ Run the handler thread """ + resolver = Resolver(True, self.ttl) + + request = Message.from_bytes(self.data) + answs, adds, auths = [], [], [] + + for req in request.questions: + rrs = [ + r for r in self.zone + if r.match(type_=req.qtype, class_=req.qclass, name=req.qname)] + if rrs != []: + auths += rrs + elif req.qtype in [Type.A, Type.CNAME] and req.qclass == Class.IN: + name, cnames, addrs = resolver.gethostbyname(req.qname) + if name != req.qname: + answs.append(ResourceRecord( + str(req.qname), Type.CNAME, Class.IN, self.ttl, + CNAMERecordData(str(name)))) + # pylint: disable=bad-continuation + addrs = [ResourceRecord( + name, Type.A, Class.IN, self.ttl, + ARecordData(data)) + for data in addrs] + cnames = [ResourceRecord( + name, Type.CNAME, Class.IN, self.ttl, + CNAMERecordData(data)) + for data in cnames] + if req.qtype == Type.A: + answs += addrs + cnames + if req.qtype == Type.CNAME: + answs += cnames + + header = Header( + request.header.ident, 0, 0, len(answs), len(auths), len(adds)) + response = Message(header, None, answs, auths, adds) + + self.skt.sendto(response.to_bytes(), self.addr) + + +class Server(object): + """ A recursive DNS server """ + + def __init__(self, port, caching, ttl): + """ Initialize the server + + Args: + port (int): port that server is listening on + caching (bool): server uses resolver with caching if true + ttl (int): ttl for records (if > 0) of cache + """ + self.caching = caching + self.ttl = ttl + self.port = port + self.done = False + self.zone = [] + + self.skt = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + + def serve(self): + """ Start serving request """ + self.skt.bind(('localhost', self.port)) + while not self.done: + data, addr = self.skt.recvfrom(512) + reqh = RequestHandler( + self.skt, self.ttl, data, addr, zone=self.zone) + reqh.start() + + def shutdown(self): + """ Shutdown the server """ + self.skt.close() + self.done = True + + def parse_zone_file(self, fname): + """Parse a zone file + + Will crash if the zone file has incorrect syntax. + """ + with open(fname) as zonef: + zone = zonef.read() + ttl, class_ = 3600000, Class.IN + for match in re.finditer(rgx.ZONE_LINE_DOMAIN, zone, re.MULTILINE): + match = match.groups() + name = match[0][:-1] + ttl = int(match[1] or match[4] or ttl * 1000) / 1000 + class_ = Class.from_string( + match[2] or match[3] or Class.to_string(class_)) + type_ = Type.from_string(match[5]) + data = match[6] + + if type_ == Type.A: + cls = ARecordData + elif type_ == Type.NS: + cls = NSRecordData + elif type_ == Type.CNAME: + cls = CNAMERecordData + else: + continue + + record = ResourceRecord(name, type_, class_, ttl, cls(data)) + self.zone.append(record) |