summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCamil Staps2016-05-24 18:16:17 +0200
committerCamil Staps2016-05-24 18:16:17 +0200
commitbc1d79113ad3fdfcf3319b3cc36f1a0253e64f9d (patch)
treea6ef59be1ec2fcda8763093711e20213b486de63
parentMade cli apps executable (diff)
Caching (mostly from da97de6)
-rw-r--r--project2/proj2_s4498062/.gitignore3
-rw-r--r--project2/proj2_s4498062/dns/cache.py55
-rw-r--r--project2/proj2_s4498062/dns/resolver.py48
-rw-r--r--project2/proj2_s4498062/dns/resource.py26
4 files changed, 85 insertions, 47 deletions
diff --git a/project2/proj2_s4498062/.gitignore b/project2/proj2_s4498062/.gitignore
index 94487b9..0191c0c 100644
--- a/project2/proj2_s4498062/.gitignore
+++ b/project2/proj2_s4498062/.gitignore
@@ -1 +1,2 @@
-*.pyc
+*.pyc
+.dns.cache
diff --git a/project2/proj2_s4498062/dns/cache.py b/project2/proj2_s4498062/dns/cache.py
index 3ef14b3..9cde66f 100644
--- a/project2/proj2_s4498062/dns/cache.py
+++ b/project2/proj2_s4498062/dns/cache.py
@@ -7,6 +7,7 @@ It is highly recommended to use these.
"""
import json
+import time
from dns.resource import ResourceRecord, RecordData
from dns.types import Type
@@ -22,11 +23,12 @@ class ResourceEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, ResourceRecord):
return {
- "name": obj.name,
- "type": Type.to_string(obj.type_),
- "class": Class.to_string(obj.class_),
- "ttl": obj.ttl,
- "rdata": obj.rdata.data
+ "name": obj.name,
+ "type": Type.to_string(obj.type_),
+ "class": Class.to_string(obj.class_),
+ "ttl": obj.ttl,
+ "rdata": obj.rdata.data,
+ "timestamp": obj.timestamp
}
return json.JSONEncoder.default(self, obj)
@@ -42,20 +44,33 @@ def resource_from_json(dct):
class_ = Class.from_string(dct["class"])
ttl = dct["ttl"]
rdata = RecordData.create(type_, dct["rdata"])
- return ResourceRecord(name, type_, class_, ttl, rdata)
+ timestamp = dct["timestamp"]
+ return ResourceRecord(name, type_, class_, ttl, rdata, timestamp)
class RecordCache(object):
""" Cache for ResourceRecords """
- def __init__(self, ttl):
+ FILE = '.dns.cache'
+
+ def __init__(self):
""" Initialize the RecordCache
Args:
ttl (int): TTL of cached entries (if > 0)
"""
self.records = []
- self.ttl = ttl
+ self.read_cache_file()
+
+ def __del__(self):
+ self.write_cache_file()
+
+ def remove_old(self):
+ """Remove entries for which the TTL has expired"""
+ now = int(time.clock())
+ for record in reversed(self.records):
+ if record.ttl + record.timestamp < now:
+ self.records.remove(record)
def lookup(self, dname, type_, class_):
""" Lookup resource records in cache
@@ -68,7 +83,16 @@ class RecordCache(object):
type_ (Type): type
class_ (Class): class
"""
- pass
+ self.remove_old()
+ return [
+ r for r in self.records
+ if r.match(name=dname, type_=type_, class_=class_)]
+
+ def add_records_from(self, msg):
+ for record in msg.answers + msg.authorities + msg.additionals:
+ if record.type_ in [Type.A, Type.AAAA, Type.CNAME, Type.NS] and \
+ record.class_ == Class.IN:
+ self.add_record(record)
def add_record(self, record):
""" Add a new Record to the cache
@@ -76,12 +100,19 @@ class RecordCache(object):
Args:
record (ResourceRecord): the record added to the cache
"""
- pass
+ self.records.append(record)
def read_cache_file(self):
""" Read the cache file from disk """
- pass
+ try:
+ with open(self.FILE, 'r') as jsonfile:
+ self.records = json.load(
+ jsonfile, object_hook=resource_from_json)
+ except IOError:
+ pass
def write_cache_file(self):
""" Write the cache file to disk """
- pass
+ self.remove_old()
+ with open(self.FILE, 'w') as jsonfile:
+ json.dump(self.records, jsonfile, cls=ResourceEncoder, indent=4)
diff --git a/project2/proj2_s4498062/dns/resolver.py b/project2/proj2_s4498062/dns/resolver.py
index ffb7c16..fe46492 100644
--- a/project2/proj2_s4498062/dns/resolver.py
+++ b/project2/proj2_s4498062/dns/resolver.py
@@ -13,7 +13,7 @@ import socket
from dns.classes import Class
from dns.types import Type
-import dns.cache
+from dns.cache import RecordCache
from dns.message import Message, Question, Header
import dns.rcodes
@@ -48,8 +48,16 @@ class Resolver(object):
self.ttl = ttl
self.timeout = timeout
- def do_query(self, hint, hostname, type_, class_=Class.IN):
+ if self.caching:
+ self.cache = RecordCache()
+
+ def do_query(self, hint, hostname, type_, class_=Class.IN, caching=True):
"""Do a query to a hint"""
+ if self.caching and caching:
+ records = self.cache.lookup(hostname, type_, class_)
+ if records != []:
+ return records
+
ident = randint(0, 65535)
header = Header(ident, 0, 1, 0, 0, 0)
header.qr = 0
@@ -65,19 +73,22 @@ class Resolver(object):
data = sock.recv(512)
resp = Message.from_bytes(data)
if resp.header.ident == ident:
- return resp
+ if self.caching and caching:
+ self.cache.add_records_from(resp)
+ return resp.answers + resp.authorities + resp.additionals
except socket.timeout:
pass
- return None
+ return []
- def do_query_to_multiple(self, hints, hostname, type_, class_=Class.IN):
+ def do_query_to_multiple(
+ self, hints, hostname, type_, class_=Class.IN, caching=True):
"""Do a query to multiple hints, return the remaining hints"""
while hints != []:
hint = hints.pop()
- response = self.do_query(hint, hostname, type_, class_)
+ response = self.do_query(hint, hostname, type_, class_, caching)
if response is not None:
return hints, response
- return [], None
+ return [], []
def gethostbyname(self, hostname):
""" Translate a host name to IPv4 address.
@@ -101,11 +112,7 @@ class Resolver(object):
domain = domains.pop(-1)
aliases = []
while hints != []:
- hints, resp = self.do_query_to_multiple(hints, domain, Type.A)
- if resp == None:
- continue
-
- info = resp.answers + resp.authorities + resp.additionals
+ hints, info = self.do_query_to_multiple(hints, domain, Type.A)
aliases += [
r.rdata.data for r in info
@@ -117,7 +124,8 @@ class Resolver(object):
if r.match(type_=Type.A, class_=Class.IN, name=domain)]
if ips != []:
return hostname, aliases, ips
- # Case 2: name servers
+
+ # Case 2: name servers for this domain
auths = [
r.rdata.data for r in info
if r.match(type_=Type.NS, class_=Class.IN, name=domain)]
@@ -135,21 +143,23 @@ class Resolver(object):
if domain != hostname:
domain = domains.pop(-1) + '.' + domain
continue
- # Case 3: delegation to other name servers
+
+ # Case 3: name servers for the same domain
parent = '.'.join(domain.split('.')[1:])
- refs = [
+ auths = [
r.rdata.data for r in info
if r.match(type_=Type.NS, class_=Class.IN, name=parent)]
ips = [
- add.rdata.data for ns in refs for add in info
+ add.rdata.data for ns in auths for add in info
if add.match(name=ns, type_=Type.A)]
if ips != []:
hints += ips
continue
- if refs != []:
- refs = [h for r in refs for h in self.gethostbyname(r)[2]]
- hints += refs
+ if auths != []:
+ auths = [h for r in auths for h in self.gethostbyname(r)[2]]
+ hints += auths
continue
+
# Case 4: aliases
for alias in aliases:
_, extra_aliases, alias_addresses = self.gethostbyname(alias)
diff --git a/project2/proj2_s4498062/dns/resource.py b/project2/proj2_s4498062/dns/resource.py
index 89201ec..b1c8ae4 100644
--- a/project2/proj2_s4498062/dns/resource.py
+++ b/project2/proj2_s4498062/dns/resource.py
@@ -1,5 +1,3 @@
-#!/usr/bin/env python2
-
""" A DNS resource record
This class contains classes for DNS resource records and record data. This
@@ -9,13 +7,14 @@ of your resolver and server.
import socket
import struct
+import time
from dns.types import Type
class ResourceRecord(object):
""" DNS resource record """
- def __init__(self, name, type_, class_, ttl, rdata):
+ def __init__(self, name, type_, class_, ttl, rdata, timestamp=time.time()):
""" Create a new resource record
Args:
@@ -29,6 +28,7 @@ class ResourceRecord(object):
self.class_ = class_
self.ttl = ttl
self.rdata = rdata
+ self.timestamp = timestamp
def match(self, name=None, type_=None, class_=None, ttl=None):
"""Check if the record matches properties"""
@@ -51,16 +51,12 @@ class ResourceRecord(object):
""" Convert ResourceRecord from bytes """
names, offset = parser.from_bytes(packet, offset, 1)
name = names[0]
- type_, class_, ttl, rdlength = struct.unpack_from(
- "!HHIH", packet, offset)
+ type_, class_, ttl, rdlen = struct.unpack_from("!HHIH", packet, offset)
offset += 10
- rdata = RecordData.from_bytes(type_, packet, offset, rdlength, parser)
- offset += rdlength
+ rdata = RecordData.from_bytes(type_, packet, offset, rdlen, parser)
+ offset += rdlen
return cls(name, type_, class_, ttl, rdata), offset
- def __repr__(self):
- return ' '.join(map(str, [self.name, self.type_, self.rdata.data]))
-
class RecordData(object):
""" Record Data """
@@ -121,7 +117,7 @@ class RecordData(object):
class ARecordData(RecordData):
- """ Record data for A type """
+ """Data of an A record"""
def to_bytes(self, offset, composer):
""" Convert to bytes
@@ -147,7 +143,7 @@ class ARecordData(RecordData):
class CNAMERecordData(RecordData):
- """ Record data for CNAME type """
+ """Data of a CNAME record"""
def to_bytes(self, offset, composer):
""" Convert to bytes
@@ -174,7 +170,7 @@ class CNAMERecordData(RecordData):
class NSRecordData(RecordData):
- """ Record data for NS type """
+ """Data of an NS record"""
def to_bytes(self, offset, composer):
""" Convert to bytes
@@ -201,7 +197,7 @@ class NSRecordData(RecordData):
class AAAARecordData(RecordData):
- """ Record data for AAAA type """
+ """Data of an AAAA record"""
def to_bytes(self, offset, composer):
""" Convert to bytes
@@ -227,7 +223,7 @@ class AAAARecordData(RecordData):
class GenericRecordData(RecordData):
- """ Generic Record Data (for other types) """
+ """Data of a generic record"""
def to_bytes(self, offset, composer):
""" Convert to bytes