[Git][NTPsec/ntpsec][master] 2 commits: Moved generally useful test jigs into jigs.py

Ian Bruene gitlab at mg.gitlab.com
Mon Aug 28 23:22:06 UTC 2017


Ian Bruene pushed to branch master at NTPsec / ntpsec


Commits:
e74c8f0c by Ian Bruene at 2017-08-28T16:24:02-05:00
Moved generally useful test jigs into jigs.py

- - - - -
f54d6919 by Ian Bruene at 2017-08-28T18:20:28-05:00
Added test for canonicalize_dns() and the necessary support in the test jigs.

- - - - -


3 changed files:

- + tests/pylib/jigs.py
- tests/pylib/test_packet.py
- tests/pylib/test_util.py


Changes:

=====================================
tests/pylib/jigs.py
=====================================
--- /dev/null
+++ b/tests/pylib/jigs.py
@@ -0,0 +1,192 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from __future__ import print_function, division
+
+import socket
+import select
+
+class FileJig:
+    def __init__(self):
+        self.data = []
+        self.flushed = False
+        self.readline_return = [""]
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        return False
+
+    def __iter__(self):
+        return self.readline_return.__iter__()
+
+    def write(self, data):
+        self.data.append(data)
+        self.flushed = False
+
+    def flush(self):
+        self.flushed = True
+
+    def readline(self):
+        if len(self.readline_return) > 0:
+            return self.readline_return.pop(0)
+        return ""
+
+
+class SocketJig:
+    def __init__(self):
+        self.data = []
+        self.return_data = []
+        self.closed = False
+        self.connected = None
+        self.fail_connect = False
+        self.fail_send = 0
+
+    def sendall(self, data):
+        if self.fail_send > 0:
+            self.fail_send -= 1
+            raise socket.error()
+        self.data.append(data)
+
+    def close(self):
+        self.closed = True
+
+    def connect(self, addr):
+        if self.fail_connect is True:
+            err = socket.error()
+            err.strerror = "socket!"
+            err.errno = 16
+            raise err
+        self.connected = addr
+
+    def recv(self, bytecount):
+        if len(self.return_data) > 0:
+            current = self.return_data.pop(0)
+            if len(current) > bytecount:
+                ret = current[:bytecount]
+                current = current[bytecount:]
+                self.return_data.insert(0, current)  # push unwanted data
+                return ret
+            else:
+                return current
+        return None
+
+
+class HasherJig:
+    def __init__(self):
+        self.update_calls = []
+        self.digest_size = 0
+
+    def update(self, data):
+        self.update_calls.append(data)
+        if (data is not None) and (data != "None"):
+            self.digest_size += 1
+
+    def digest(self):
+        return "blah" * 4  # 16 byte hash
+
+
+class SocketModuleJig:
+    error = socket.error
+    gaierror = socket._socket.gaierror
+    SOCK_DGRAM = socket.SOCK_DGRAM
+    IPPROTO_UDP = socket.IPPROTO_UDP
+    AF_UNSPEC = socket.AF_UNSPEC
+    AI_NUMERICHOST = socket.AI_NUMERICHOST
+    AI_CANONNAME = socket.AI_CANONNAME
+    EAI_NONAME = socket.EAI_NONAME
+    EAI_NODATA = socket.EAI_NODATA
+    NI_NAMEREQD = socket.NI_NAMEREQD
+
+    def __init__(self):
+        self.gai_calls = []
+        self.gai_error_count = 0
+        self.gai_returns = []
+        self.gni_calls = []
+        self.gni_error_count = 0
+        self.gni_returns = []
+        self.socket_calls = []
+        self.socket_fail = False
+        self.socket_fail_connect = False
+        self.socketsReturned = []
+        self.inet_ntop_calls = []
+
+    def getaddrinfo(self, host, port, family=None, socktype=None,
+                    proto=None, flags=None):
+        self.gai_calls.append((host, port, family, socktype, proto, flags))
+        if self.gai_error_count > 0:
+            self.gai_error_count -= 1
+            err = self.gaierror("blah")
+            err.errno = socket.EAI_NONAME
+            raise err
+        return self.gai_returns.pop(0)
+
+    def getnameinfo(self, addr, flags):
+        self.gni_calls.append((addr, flags))
+        if self.gni_error_count > 0:
+            self.gni_error_count -= 1
+            err = self.gaierror("blah")
+            err.errno = socket.EAI_NONAME
+            raise err
+        return self.gni_returns.pop(0)
+
+    def socket(self, family, socktype, protocol):
+        self.socket_calls.append((family, socktype, protocol))
+        if self.socket_fail is True:
+            err = self.error()
+            err.strerror = "error!"
+            err.errno = 23
+            raise err
+        sock = SocketJig()
+        if self.socket_fail_connect is True:
+            sock.fail_connect = True
+        self.socketsReturned.append(sock)
+        return sock
+
+    def inet_ntop(self, addr, family):
+        self.inet_ntop_calls.append((addr, family))
+        return "canon.com"
+
+
+class GetpassModuleJig:
+    def __init__(self):
+        self.getpass_calls = []
+
+    def getpass(self, prompt, stream=None):
+        self.getpass_calls.append((prompt, stream))
+        return "xyzzy"
+
+
+class HashlibModuleJig:
+    def __init__(self):
+        self.new_calls = []
+        self.hashers_returned = []
+
+    def new(self, name):
+        self.new_calls.append(name)
+        h = HasherJig()
+        self.hashers_returned.append(h)
+        return h
+
+
+class SelectModuleJig:
+    error = select.error
+
+    def __init__(self):
+        self.select_calls = []
+        self.select_fail = 0
+        self.do_return = []
+
+    def select(self, ins, outs, excepts, timeout=0):
+        self.select_calls.append((ins, outs, excepts, timeout))
+        if self.select_fail > 0:
+            self.select_fail -= 1
+            raise select.error
+        if len(self.do_return) == 0:  # simplify code that doesn't need it
+            self.do_return.append(True)
+        doreturn = self.do_return.pop(0)
+        if doreturn is True:
+            return (ins, [], [])
+        else:
+            return ([], [], [])


=====================================
tests/pylib/test_packet.py
=====================================
--- a/tests/pylib/test_packet.py
+++ b/tests/pylib/test_packet.py
@@ -13,77 +13,13 @@ import select
 import sys
 import getpass
 
+from jigs import *
+
 odict = ntp.util.OrderedDict
 
 ntpp = ntp.packet
 ctlerr = ntp.packet.ControlException
 
-class FileJig:
-    def __init__(self):
-        self.data = []
-        self.flushed = False
-        self.readline_return = [""]
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, *args):
-        return False
-
-    def __iter__(self):
-        return self.readline_return.__iter__()
-
-    def write(self, data):
-        self.data.append(data)
-        self.flushed = False
-
-    def flush(self):
-        self.flushed = True
-
-    def readline(self):
-        if len(self.readline_return) > 0:
-            return self.readline_return.pop(0)
-        return ""
-
-
-class SocketJig:
-    def __init__(self):
-        self.data = []
-        self.return_data = []
-        self.closed = False
-        self.connected = None
-        self.fail_connect = False
-        self.fail_send = 0
-
-    def sendall(self, data):
-        if self.fail_send > 0:
-            self.fail_send -= 1
-            raise socket.error()
-        self.data.append(data)
-
-    def close(self):
-        self.closed = True
-
-    def connect(self, addr):
-        if self.fail_connect is True:
-            err = socket.error()
-            err.strerror = "socket!"
-            err.errno = 16
-            raise err
-        self.connected = addr
-
-    def recv(self, bytecount):
-        if len(self.return_data) > 0:
-            current = self.return_data.pop(0)
-            if len(current) > bytecount:
-                ret = current[:bytecount]
-                current = current[bytecount:]
-                self.return_data.insert(0, current)  # push unwanted data
-                return ret
-            else:
-                return current
-        return None
-
 
 class SessionJig:
     def __init__(self):
@@ -113,111 +49,6 @@ class ControlPacketJig:
         return self.extension
 
 
-class HasherJig:
-    def __init__(self):
-        self.update_calls = []
-        self.digest_size = 0
-
-    def update(self, data):
-        self.update_calls.append(data)
-        if (data is not None) and (data != "None"):
-            self.digest_size += 1
-
-    def digest(self):
-        return "blah" * 4  # 16 byte hash
-
-
-class SocketModuleJig:
-    error = socket.error
-    gaierror = socket._socket.gaierror
-    SOCK_DGRAM = socket.SOCK_DGRAM
-    IPPROTO_UDP = socket.IPPROTO_UDP
-    AF_UNSPEC = socket.AF_UNSPEC
-    AI_NUMERICHOST = socket.AI_NUMERICHOST
-    AI_CANONNAME = socket.AI_CANONNAME
-    EAI_NONAME = socket.EAI_NONAME
-    EAI_NODATA = socket.EAI_NODATA
-
-    def __init__(self):
-        self.gai_calls = []
-        self.gai_error_count = 0
-        self.socket_calls = []
-        self.socket_fail = False
-        self.socket_fail_connect = False
-        self.socketsReturned = []
-        self.inet_ntop_calls = []
-
-    def getaddrinfo(self, host, port, family=None, socktype=None,
-                    proto=None, flags=None):
-        self.gai_calls.append((host, port, family, socktype, proto, flags))
-        if self.gai_error_count > 0:
-            self.gai_error_count -= 1
-            err = self.gaierror("blah")
-            err.errno = socket.EAI_NONAME
-            raise err
-        return 42
-
-    def socket(self, family, socktype, protocol):
-        self.socket_calls.append((family, socktype, protocol))
-        if self.socket_fail is True:
-            err = self.error()
-            err.strerror = "error!"
-            err.errno = 23
-            raise err
-        sock = SocketJig()
-        if self.socket_fail_connect is True:
-            sock.fail_connect = True
-        self.socketsReturned.append(sock)
-        return sock
-
-    def inet_ntop(self, addr, family):
-        self.inet_ntop_calls.append((addr, family))
-        return "canon.com"
-
-
-class GetpassModuleJig:
-    def __init__(self):
-        self.getpass_calls = []
-
-    def getpass(self, prompt, stream=None):
-        self.getpass_calls.append((prompt, stream))
-        return "xyzzy"
-
-
-class HashlibModuleJig:
-    def __init__(self):
-        self.new_calls = []
-        self.hashers_returned = []
-
-    def new(self, name):
-        self.new_calls.append(name)
-        h = HasherJig()
-        self.hashers_returned.append(h)
-        return h
-
-
-class SelectModuleJig:
-    error = select.error
-
-    def __init__(self):
-        self.select_calls = []
-        self.select_fail = 0
-        self.do_return = []
-
-    def select(self, ins, outs, excepts, timeout=0):
-        self.select_calls.append((ins, outs, excepts, timeout))
-        if self.select_fail > 0:
-            self.select_fail -= 1
-            raise select.error
-        if len(self.do_return) == 0:  # simplify code that doesn't need it
-            self.do_return.append(True)
-        doreturn = self.do_return.pop(0)
-        if doreturn is True:
-            return (ins, [], [])
-        else:
-            return ([], [], [])
-
-
 class AuthenticatorJig:
     compute_mac_calls = []
 
@@ -245,6 +76,7 @@ class AuthenticatorJig:
                                                    keytype, passwd))
         return "mac"
 
+
 # ==========================================================
 #  Tests
 # =========================================================
@@ -843,6 +675,7 @@ class TestControlSession(unittest.TestCase):
             cls.debug = 3
             cls.logfp = logjig
             # Test first type
+            fakesockmod.gai_returns = [42]
             result = cls._ControlSession__lookuphost("blah.com", "family")
             self.assertEqual(result, 42)
             self.assertEqual(fakesockmod.gai_calls,
@@ -853,6 +686,7 @@ class TestControlSession(unittest.TestCase):
             # Test second type
             logjig.__init__()  # reset
             fakesockmod.__init__()
+            fakesockmod.gai_returns = [42]
             fakesockmod.gai_error_count = 1
             result = cls._ControlSession__lookuphost("blah.com", "family")
             self.assertEqual(result, 42)
@@ -868,6 +702,7 @@ class TestControlSession(unittest.TestCase):
             # Test third type
             logjig.__init__()  # reset
             fakesockmod.__init__()
+            fakesockmod.gai_returns = [42]
             fakesockmod.gai_error_count = 2
             result = cls._ControlSession__lookuphost("blah.com", "family")
             self.assertEqual(result, 42)


=====================================
tests/pylib/test_util.py
=====================================
--- a/tests/pylib/test_util.py
+++ b/tests/pylib/test_util.py
@@ -4,6 +4,8 @@
 import unittest
 import ntp.util
 
+import jigs
+
 
 class TestPylibUtilMethods(unittest.TestCase):
 
@@ -404,5 +406,52 @@ class TestPylibUtilMethods(unittest.TestCase):
         finally:
             ntp.util.monoclock = monotemp
 
+    def test_canonicalize_dns(self):
+        f = ntp.util.canonicalize_dns
+
+        fakesockmod = jigs.SocketModuleJig()
+        mycache = ntp.util.Cache()
+        mycache.set("foo", "bar")
+        try:
+            cachetemp = ntp.util.canonicalization_cache
+            ntp.util.canonicalization_cache = mycache
+            sockettemp = ntp.util.socket
+            ntp.util.socket = fakesockmod
+            # Test cache hit
+            self.assertEqual(f("foo"), "bar")
+            self.assertEqual(fakesockmod.gai_calls, [])
+            # Test addrinfo fail
+            fakesockmod.__init__()
+            fakesockmod.gai_error_count = 1
+            self.assertEqual(f("none"), "DNSFAIL:none")
+            self.assertEqual(fakesockmod.gai_calls,
+                             [("none", None, 0, 0, 0, 2)])
+            self.assertEqual(fakesockmod.gni_calls, [])
+            # Test nameinfo fail
+            fakesockmod.__init__()
+            fakesockmod.gni_error_count = 1
+            fakesockmod.gni_returns = [("www.Hastur.madness", 42)]
+            fakesockmod.gai_returns = [(("family", "socktype", "proto",
+                                         "san.Hastur.madness", "42.23.%$.(#"),)]
+            self.assertEqual(f("bar:42"), "san.hastur.madness:42")
+            # Test nameinfo fail, no canonname
+            fakesockmod.__init__()
+            mycache.__init__()
+            fakesockmod.gni_error_count = 1
+            fakesockmod.gni_returns = [("www.Hastur.madness", 42)]
+            fakesockmod.gai_returns = [(("family", "socktype", "proto",
+                                         None, "42.23.%$.(#"),)]
+            self.assertEqual(f("bar:42"), "bar:42")
+            # Test success
+            fakesockmod.__init__()
+            mycache.__init__()
+            fakesockmod.gni_returns = [("www.Hastur.madness", 42)]
+            fakesockmod.gai_returns = [(("family", "socktype", "proto",
+                                         None, "42.23.%$.(#"),)]
+            self.assertEqual(f("bar:42"), "www.hastur.madness:42")
+        finally:
+            ntp.util.canonicalization_cache = cachetemp
+            ntp.util.socket = sockettemp
+
 if __name__ == '__main__':
     unittest.main()



View it on GitLab: https://gitlab.com/NTPsec/ntpsec/compare/3c374171f6a2528704a28af97f4d92e4f9674181...f54d6919c53317b87a233bc64d22a266272d9ae7

---
View it on GitLab: https://gitlab.com/NTPsec/ntpsec/compare/3c374171f6a2528704a28af97f4d92e4f9674181...f54d6919c53317b87a233bc64d22a266272d9ae7
You're receiving this email because of your account on gitlab.com.
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <https://lists.ntpsec.org/pipermail/vc/attachments/20170828/4f1cdd8f/attachment.html>


More information about the vc mailing list