summaryrefslogtreecommitdiff
path: root/project2/proj2_s4498062/dns/resolver.py
diff options
context:
space:
mode:
Diffstat (limited to 'project2/proj2_s4498062/dns/resolver.py')
-rw-r--r--project2/proj2_s4498062/dns/resolver.py61
1 files changed, 44 insertions, 17 deletions
diff --git a/project2/proj2_s4498062/dns/resolver.py b/project2/proj2_s4498062/dns/resolver.py
index b57e044..af60686 100644
--- a/project2/proj2_s4498062/dns/resolver.py
+++ b/project2/proj2_s4498062/dns/resolver.py
@@ -8,6 +8,7 @@ client and the DNS server, but with a different list of servers.
import re
import socket
+from dns.cache import RecordCache
from dns.classes import Class
from dns.message import Message, Header, Question
from dns.types import Type
@@ -45,6 +46,9 @@ class Resolver(object):
self.caching = caching
self.ttl = ttl
+ if self.caching:
+ self.cache = RecordCache()
+
def do_query(self, query, using):
"""Send a query to a list of name servers"""
for hint in using:
@@ -56,10 +60,27 @@ class Resolver(object):
sock.sendto(query.to_bytes(), (hint, 53))
data = sock.recv(512)
response = Message.from_bytes(data)
+
+ if self.caching:
+ for record in response.answers + \
+ response.authorities + \
+ response.additionals:
+ self.cache.add_record(record)
+
yield response
except socket.timeout:
pass
+ def try_hint(self, sub, dom, hints):
+ """Helper for get_hints"""
+ if sub == '':
+ for hint in hints:
+ yield hint
+ else:
+ for new_hint in self.get_hints(
+ sub, dom, hints, in_recursion=True):
+ yield new_hint
+
def get_hints(self, domain, parent='', using=None, in_recursion=False):
"""Get a list of nameservers for a domain"""
if using is None:
@@ -68,13 +89,21 @@ class Resolver(object):
if not in_recursion:
using += self.nameservers
+ print 'Trying', domain, parent, 'using', using
+
domains = re.match(rgx.DOMAIN, domain)
if domains is None:
- return None
+ return
sub, dom = domains.groups()
if parent != '':
dom += '.' + parent
+ if self.caching:
+ hints = self.cache.lookup(dom, Type.NS, Class.IN)
+ for hint in self.try_hint(
+ sub, dom, [r.rdata.data for r in hints]):
+ yield hint
+
header = Header(0, 0, 1, 0, 0, 0)
header.qr = 0
header.opcode = 0
@@ -82,26 +111,13 @@ class Resolver(object):
query = Message(header, [Question(dom, Type.NS, Class.IN)])
for response in self.do_query(query, using):
- new_hints = [ip for _, [ip] in list(response.get_hints())]
-
- if new_hints != []:
- if sub is '':
- return new_hints + using
-
- result = self.get_hints(
- sub, dom, new_hints + using, in_recursion=True)
- if result is not None:
- return result
-
- return []
+ hints = [ip for _, [ip] in response.get_hints()]
+ for hint in self.try_hint(sub, dom, hints):
+ yield hint
def gethostbyname(self, hostname):
""" Translate a host name to IPv4 address.
- Currently this method contains an example. You will have to replace
- this example with example with the algorithm described in section
- 5.3.3 in RFC 1034.
-
Args:
hostname (str): the hostname to resolve
@@ -111,6 +127,15 @@ class Resolver(object):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.settimeout(self.timeout)
+ if self.caching:
+ addrs = list(self.cache.lookup(hostname, Type.A, Class.IN))
+ cnames = list(self.cache.lookup(hostname, Type.CNAME, Class.IN))
+ if addrs != []:
+ return (
+ hostname,
+ [r.rdata.data for r in cnames],
+ [r.rdata.data for r in addrs])
+
# Create and send query
question = Question(hostname, Type.A, Class.IN)
header = Header(9001, 0, 1, 0, 0, 0)
@@ -124,3 +149,5 @@ class Resolver(object):
addresses = response.get_addresses()
return hostname, aliases, addresses
+
+ return hostname, [], []