"""A recursive DNS server

This module provides a recursive DNS server. You will have to implement this
server using the algorithm described in section 4.3.2 of RFC 1034.
"""

import re
import socket
from threading import Thread

import dns.regexes as rgx
from dns.classes import Class
from dns.types import Type
from dns.message import Header, Message
from dns.resolver import Resolver
from dns.resource import \
    ResourceRecord, ARecordData, NSRecordData, CNAMERecordData


class RequestHandler(Thread):
    """A handler thread that answers a single request to the DNS server"""

    def __init__(self, skt, ttl, data, addr, zone, resolv):
        # pylint: disable=too-many-arguments
        """Initialize the handler thread

        Args:
            skt: socket the response is sent on
            ttl: ttl placed on every record built from resolver results
            data: raw bytes of the received request
            addr: address the response is sent back to
            zone: iterable of zone records that are tried before resolving
            resolv: resolver whose ``gethostbyname`` is used for lookups
        """
        super(RequestHandler, self).__init__()
        # Daemon thread: a pending request never keeps the process alive.
        self.daemon = True
        self.skt = skt
        self.ttl = ttl
        self.data = data
        self.addr = addr
        self.zone = zone
        self.resolv = resolv

    def run(self):
        """Parse the request, build a response and send it to the client

        For every question the local zone is consulted first; matching zone
        records are collected in ``auths``. Otherwise, A and CNAME questions
        of class IN are handed to the resolver and the results are collected
        in ``answs``. Any other question contributes no records.
        """
        request = Message.from_bytes(self.data)
        answs, adds, auths = [], [], []

        for req in request.questions:
            zone_hits = [
                record for record in self.zone
                if record.match(
                    type_=req.qtype, class_=req.qclass, name=req.qname)]

            if zone_hits:
                auths += zone_hits
                continue

            resolvable = (
                req.qtype in [Type.A, Type.CNAME] and req.qclass == Class.IN)
            if not resolvable:
                continue

            name, cnames, addrs = self.resolv.gethostbyname(req.qname)

            # The resolver reported a different name than the one asked
            # for: link the queried name to it with a CNAME record.
            if name != req.qname:
                answs.append(ResourceRecord(
                    str(req.qname), Type.CNAME, Class.IN, self.ttl,
                    CNAMERecordData(str(name))))

            addr_records = [
                ResourceRecord(
                    name, Type.A, Class.IN, self.ttl, ARecordData(entry))
                for entry in addrs]
            cname_records = [
                ResourceRecord(
                    name, Type.CNAME, Class.IN, self.ttl,
                    CNAMERecordData(entry))
                for entry in cnames]

            if req.qtype == Type.A:
                answs += addr_records + cname_records
            if req.qtype == Type.CNAME:
                answs += cname_records

        # NOTE(review): the two literal zeros are passed to Header exactly
        # as before; presumably they are the flags and the question count
        # -- confirm against dns.message.Header.
        header = Header(
            request.header.ident, 0, 0, len(answs), len(auths), len(adds))
        response = Message(header, None, answs, auths, adds)
        self.skt.sendto(response.to_bytes(), self.addr)


class Server(object):
    """A recursive DNS server"""

    def __init__(self, port, caching, ttl):
        """Initialize the server

        Args:
            port (int): port that server is listening on
            caching (bool): server uses resolver with caching if true
            ttl (int): ttl for records (if > 0) of cache
        """
        self.caching = caching
        self.ttl = ttl
        self.port = port
        self.done = False
        self.zone = []
        # Honour the caller's caching choice; it used to be hard-coded to
        # True, which silently ignored the ``caching`` argument.
        # NOTE(review): assumes Resolver's first positional argument is the
        # caching flag -- confirm against dns.resolver.Resolver.
        self.resolv = Resolver(self.caching, self.ttl)
        self.skt = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

    def serve(self):
        """Start serving requests

        Binds the socket to localhost on the configured port and starts one
        RequestHandler thread per received datagram until ``shutdown`` is
        called.

        Raises:
            socket.error: if binding fails, or if receiving fails while the
                server has not been shut down.
        """
        self.skt.bind(('localhost', self.port))
        while not self.done:
            try:
                # At most 512 bytes are read per datagram.
                data, addr = self.skt.recvfrom(512)
            except socket.error:
                # A failure after shutdown() was requested is the expected
                # consequence of the socket being closed: leave quietly.
                if self.done:
                    break
                raise
            handler = RequestHandler(
                self.skt, self.ttl, data, addr, self.zone, self.resolv)
            handler.start()

    def shutdown(self):
        """Shutdown the server"""
        # Raise the flag before closing the socket so that the serving loop
        # can tell a shutdown apart from a genuine socket error.
        self.done = True
        self.skt.close()

    def parse_zone_file(self, fname):
        """Parse a zone file and append its records to ``self.zone``

        Only A, NS and CNAME records are stored; lines of any other type
        are skipped. When a line omits the ttl or the class, the value of
        the previous line (initially 3600000 and IN) is used.

        Will crash if the zone file has incorrect syntax.

        Args:
            fname (str): path of the zone file
        """
        with open(fname) as zonef:
            zone = zonef.read()

        ttl, class_ = 3600000, Class.IN
        for line in re.finditer(rgx.ZONE_LINE_DOMAIN, zone, re.MULTILINE):
            groups = line.groups()

            # Drop the last character of the captured name (presumably the
            # trailing dot of a fully qualified name -- TODO confirm).
            name = groups[0][:-1]

            # The captured ttl is divided by 1000; the carried-over value is
            # multiplied first so that it survives the division unchanged.
            # Floor division keeps the ttl an int on Python 3 as well.
            ttl = int(groups[1] or groups[4] or ttl * 1000) // 1000
            class_ = Class.from_string(
                groups[2] or groups[3] or Class.to_string(class_))
            type_ = Type.from_string(groups[5])
            data = groups[6]

            if type_ == Type.A:
                cls = ARecordData
            elif type_ == Type.NS:
                cls = NSRecordData
            elif type_ == Type.CNAME:
                cls = CNAMERecordData
            else:
                # Unsupported record type: ignore the line.
                continue

            self.zone.append(
                ResourceRecord(name, type_, class_, ttl, cls(data)))