diff options
Diffstat (limited to 'project2/proj2_s4498062/dns/resolver.py')
-rw-r--r-- | project2/proj2_s4498062/dns/resolver.py | 61 |
1 files changed, 44 insertions, 17 deletions
diff --git a/project2/proj2_s4498062/dns/resolver.py b/project2/proj2_s4498062/dns/resolver.py index b57e044..af60686 100644 --- a/project2/proj2_s4498062/dns/resolver.py +++ b/project2/proj2_s4498062/dns/resolver.py @@ -8,6 +8,7 @@ client and the DNS server, but with a different list of servers. import re import socket +from dns.cache import RecordCache from dns.classes import Class from dns.message import Message, Header, Question from dns.types import Type @@ -45,6 +46,9 @@ class Resolver(object): self.caching = caching self.ttl = ttl + if self.caching: + self.cache = RecordCache() + def do_query(self, query, using): """Send a query to a list of name servers""" for hint in using: @@ -56,10 +60,27 @@ class Resolver(object): sock.sendto(query.to_bytes(), (hint, 53)) data = sock.recv(512) response = Message.from_bytes(data) + + if self.caching: + for record in response.answers + \ + response.authorities + \ + response.additionals: + self.cache.add_record(record) + yield response except socket.timeout: pass + def try_hint(self, sub, dom, hints): + """Helper for get_hints""" + if sub == '': + for hint in hints: + yield hint + else: + for new_hint in self.get_hints( + sub, dom, hints, in_recursion=True): + yield new_hint + def get_hints(self, domain, parent='', using=None, in_recursion=False): """Get a list of nameservers for a domain""" if using is None: @@ -68,13 +89,21 @@ class Resolver(object): if not in_recursion: using += self.nameservers + print 'Trying', domain, parent, 'using', using + domains = re.match(rgx.DOMAIN, domain) if domains is None: - return None + return sub, dom = domains.groups() if parent != '': dom += '.' + parent + if self.caching: + hints = self.cache.lookup(dom, Type.NS, Class.IN) + for hint in self.try_hint( + sub, dom, [r.rdata.data for r in hints]): + yield hint + header = Header(0, 0, 1, 0, 0, 0) header.qr = 0 header.opcode = 0 @@ -82,26 +111,13 @@ class Resolver(object): query = Message(header, [Question(dom, Type.NS, Class.IN)]) for response in self.do_query(query, using): - new_hints = [ip for _, [ip] in list(response.get_hints())] - - if new_hints != []: - if sub is '': - return new_hints + using - - result = self.get_hints( - sub, dom, new_hints + using, in_recursion=True) - if result is not None: - return result - - return [] + hints = [ip for _, [ip] in response.get_hints()] + for hint in self.try_hint(sub, dom, hints): + yield hint def gethostbyname(self, hostname): """ Translate a host name to IPv4 address. - Currently this method contains an example. You will have to replace - this example with example with the algorithm described in section - 5.3.3 in RFC 1034. - Args: hostname (str): the hostname to resolve @@ -111,6 +127,15 @@ class Resolver(object): sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.settimeout(self.timeout) + if self.caching: + addrs = list(self.cache.lookup(hostname, Type.A, Class.IN)) + cnames = list(self.cache.lookup(hostname, Type.CNAME, Class.IN)) + if addrs != []: + return ( + hostname, + [r.rdata.data for r in cnames], + [r.rdata.data for r in addrs]) + # Create and send query question = Question(hostname, Type.A, Class.IN) header = Header(9001, 0, 1, 0, 0, 0) @@ -124,3 +149,5 @@ class Resolver(object): addresses = response.get_addresses() return hostname, aliases, addresses + + return hostname, [], [] |