[Git][NTPsec/ntpsec][master] packet.py tests now work with Python 3

Ian Bruene gitlab at mg.gitlab.com
Thu Sep 21 22:58:26 UTC 2017


Ian Bruene pushed to branch master at NTPsec / ntpsec


Commits:
9aa67300 by Ian Bruene at 2017-09-21T22:57:09Z
packet.py tests now work with Python 3

- - - - -


3 changed files:

- pylib/packet.py
- tests/pylib/jigs.py
- tests/pylib/test_packet.py


Changes:

=====================================
pylib/packet.py
=====================================
--- a/pylib/packet.py
+++ b/pylib/packet.py
@@ -978,6 +978,7 @@ class ControlSession:
         # If we have data, pad it out to a 32-bit boundary.
         # Do not include these in the payload count.
         if pkt.extension:
+            pkt.extension = polybytes(pkt.extension)
             while ((ControlPacket.HEADER_LEN + len(pkt.extension)) & 3):
                 pkt.extension += b"\x00"
 
@@ -1001,7 +1002,7 @@ class ControlSession:
         if mac is None:
             raise ControlException(SERR_NOKEY)
         else:
-            pkt.extension += mac
+            pkt.extension += polybytes(mac)
         return pkt.send()
 
     def getresponse(self, opcode, associd, timeo):
@@ -1332,7 +1333,7 @@ class ControlSession:
         elif b"\x00" in self.response:
             self.response = self.response[:self.response.index(b"\x00")]
         self.response = self.response.rstrip()
-        return self.response == "Config Succeeded"
+        return self.response == polybytes("Config Succeeded")
 
     def fetch_nonce(self):
         """
@@ -1349,7 +1350,11 @@ This combats source address spoofing
 
         # uh, oh, no nonce seen
         # this print probably never can be seen...
-        self.logfp.write("## Nonce expected: %s" % self.response)
+        if str is bytes:
+            resp = self.response
+        else:
+            resp = self.response.decode()
+        self.logfp.write("## Nonce expected: %s" % resp)
         raise ControlException(SERR_BADNONCE)
 
     def mrulist(self, variables=None, rawhook=None, direct=None):
@@ -1714,6 +1719,6 @@ class Authenticator:
         hasher = hashlib.new(keytype)
         hasher.update(passwd)
         hasher.update(payload)
-        return hasher.digest() == mac
+        return polybytes(hasher.digest()) == mac
 
 # end


=====================================
tests/pylib/jigs.py
=====================================
--- a/tests/pylib/jigs.py
+++ b/tests/pylib/jigs.py
@@ -8,6 +8,84 @@ import select
 import os.path
 
 
+master_encoding = 'latin-1'
+
+if str is bytes:  # Python 2
+    polystr = str
+    polybytes = bytes
+    polyord = ord
+    polychr = str
+    input = raw_input
+
+    def string_escape(s):
+        return s.decode('string_escape')
+
+    def make_wrapper(fp):
+        return fp
+
+else:  # Python 3
+    import io
+
+    def polystr(o):
+        "Polymorphic string factory function"
+        if isinstance(o, str):
+            return o
+        if not isinstance(o, bytes):
+            return str(o)
+        return str(o, encoding=master_encoding)
+
+    def polybytes(s):
+        "Polymorphic string encoding function"
+        if isinstance(s, bytes):
+            return s
+        if not isinstance(s, str):
+            return bytes(s)
+        return bytes(s, encoding=master_encoding)
+
+    def polyord(c):
+        "Polymorphic ord() function"
+        if isinstance(c, str):
+            return ord(c)
+        else:
+            return c
+
+    def polychr(c):
+        "Polymorphic chr() function"
+        if isinstance(c, int):
+            return chr(c)
+        else:
+            return c
+
+    def string_escape(s):
+        "Polymorphic string_escape/unicode_escape"
+        # This hack is necessary because Unicode strings in Python 3 don't
+        # have a decode method, so there's no simple way to ask it for the
+        # equivalent of decode('string_escape') in Python 2. This function
+        # assumes that it will be called with a Python 3 'str' instance
+        return s.encode(master_encoding).decode('unicode_escape')
+
+    def make_wrapper(fp):
+        "Wrapper factory function to enforce master encoding"
+        # This can be used to wrap normally binary streams for API
+        # compatibility with functions that need a text stream in
+        # Python 3; it ensures that the binary bytes are decoded using
+        # the master encoding we use to turn bytes to Unicode in
+        # polystr above
+        # newline="\n" ensures that Python 3 won't mangle line breaks
+        return io.TextIOWrapper(fp, encoding=master_encoding, newline="\n")
+
+    def make_std_wrapper(stream):
+        "Standard input/output wrapper factory function"
+        # This ensures that the encoding of standard output and standard
+        # error on Python 3 matches the master encoding we use to turn
+        # bytes to Unicode in polystr above
+        # line_buffering=True ensures that interactive
+        # command sessions work as expected
+        return io.TextIOWrapper(stream.buffer,
+                                encoding=master_encoding, newline="\n",
+                                line_buffering=True)
+
+
 class FileJig:
     def __init__(self, returns=[""]):
         self.data = []
@@ -87,11 +165,11 @@ class HasherJig:
 
     def update(self, data):
         self.update_calls.append(data)
-        if (data is not None) and (data != "None"):
+        if len(data) > 0:
             self.digest_size += 1
 
     def digest(self):
-        return "blah" * 4  # 16 byte hash
+        return polybytes("blah" * 4)  # 16 byte hash
 
 
 class SocketModuleJig:


=====================================
tests/pylib/test_packet.py
=====================================
--- a/tests/pylib/test_packet.py
+++ b/tests/pylib/test_packet.py
@@ -19,6 +19,83 @@ odict = ntp.util.OrderedDict
 ntpp = ntp.packet
 ctlerr = ntp.packet.ControlException
 
+master_encoding = 'latin-1'
+
+if str is bytes:  # Python 2
+    polystr = str
+    polybytes = bytes
+    polyord = ord
+    polychr = str
+    input = raw_input
+
+    def string_escape(s):
+        return s.decode('string_escape')
+
+    def make_wrapper(fp):
+        return fp
+
+else:  # Python 3
+    import io
+
+    def polystr(o):
+        "Polymorphic string factory function"
+        if isinstance(o, str):
+            return o
+        if not isinstance(o, bytes):
+            return str(o)
+        return str(o, encoding=master_encoding)
+
+    def polybytes(s):
+        "Polymorphic string encoding function"
+        if isinstance(s, bytes):
+            return s
+        if not isinstance(s, str):
+            return bytes(s)
+        return bytes(s, encoding=master_encoding)
+
+    def polyord(c):
+        "Polymorphic ord() function"
+        if isinstance(c, str):
+            return ord(c)
+        else:
+            return c
+
+    def polychr(c):
+        "Polymorphic chr() function"
+        if isinstance(c, int):
+            return chr(c)
+        else:
+            return c
+
+    def string_escape(s):
+        "Polymorphic string_escape/unicode_escape"
+        # This hack is necessary because Unicode strings in Python 3 don't
+        # have a decode method, so there's no simple way to ask it for the
+        # equivalent of decode('string_escape') in Python 2. This function
+        # assumes that it will be called with a Python 3 'str' instance
+        return s.encode(master_encoding).decode('unicode_escape')
+
+    def make_wrapper(fp):
+        "Wrapper factory function to enforce master encoding"
+        # This can be used to wrap normally binary streams for API
+        # compatibility with functions that need a text stream in
+        # Python 3; it ensures that the binary bytes are decoded using
+        # the master encoding we use to turn bytes to Unicode in
+        # polystr above
+        # newline="\n" ensures that Python 3 won't mangle line breaks
+        return io.TextIOWrapper(fp, encoding=master_encoding, newline="\n")
+
+    def make_std_wrapper(stream):
+        "Standard input/output wrapper factory function"
+        # This ensures that the encoding of standard output and standard
+        # error on Python 3 matches the master encoding we use to turn
+        # bytes to Unicode in polystr above
+        # line_buffering=True ensures that interactive
+        # command sessions work as expected
+        return io.TextIOWrapper(stream.buffer,
+                                encoding=master_encoding, newline="\n",
+                                line_buffering=True)
+
 
 class SessionJig:
     def __init__(self):
@@ -141,7 +218,7 @@ class TestSyncPacket(unittest.TestCase):
         self.assertEqual(cls.origin_timestamp, 0)
         self.assertEqual(cls.receive_timestamp, 0)
         self.assertEqual(cls.transmit_timestamp, 0)
-        self.assertEqual(cls.extension, '')
+        self.assertEqual(cls.extension, polybytes(''))
         self.assertEqual(cls.extfields, [])
         self.assertEqual(cls.mac, '')
         self.assertEqual(cls.hostname, None)
@@ -212,10 +289,11 @@ class TestSyncPacket(unittest.TestCase):
         self.assertEqual(cls.receive_timestamp, 0x0201020304050607)
         self.assertEqual(cls.transmit_timestamp, 0x0301020304050607)
         self.assertEqual(cls.extfields,
-                         [(1, "blah"), (2, "jabberjabber"),
-                          (3, "In the end, our choices make us.")])
-        self.assertEqual(cls.extension, ext + mac)
-        self.assertEqual(cls.mac, mac)
+                         [(1, polybytes("blah")),
+                          (2, polybytes("jabberjabber")),
+                          (3, polybytes("In the end, our choices make us."))])
+        self.assertEqual(cls.extension, polybytes(ext + mac))
+        self.assertEqual(cls.mac, polybytes(mac))
         # Test with extension, DES
         data2 = data + ext + "\x11\x22\x33\x44\x55\x66\x77\x88\x99\xAA\xBB\xCC"
         try:
@@ -246,14 +324,14 @@ class TestSyncPacket(unittest.TestCase):
               "\x0A\x0B\x0C\x0D\x0E\x0F\x10\x11\x12\x13"
         data2 = data + ext + mac
         cls = self.target(data2)
-        self.assertEqual(cls.mac, mac)
-        self.assertEqual(cls.extension, ext + mac)
+        self.assertEqual(cls.mac, polybytes(mac))
+        self.assertEqual(cls.extension, polybytes(ext + mac))
         # Test with extension, MD5 or SHA1, 24
         mac += "\x14\x15\x16\x17"
         data2 = data + ext + mac
-        cls = self.target(data2)
-        self.assertEqual(cls.mac, mac)
-        self.assertEqual(cls.extension, ext + mac)
+        cls = self.target(polybytes(data2))
+        self.assertEqual(cls.mac, polybytes(mac))
+        self.assertEqual(cls.extension, polybytes(ext + mac))
 
     def test_ntp_to_posix(self):
         f = self.target.ntp_to_posix
@@ -353,7 +431,7 @@ class TestSyncPacket(unittest.TestCase):
               "\x00\x00\x00\x03\x00\x00\x00\x20" \
               "In the end, our choices make us." \
               "\x11\x22\x33\x44"
-        pkt = data + ext
+        pkt = polybytes(data + ext)
         cls = self.target(pkt)
         self.assertEqual(cls.flatten(), pkt)
 
@@ -475,22 +553,39 @@ class TestMisc(unittest.TestCase):
         # Test sortaddr, ipv6
         cls.addr = "[11:22:33::44:55]:42"
         self.assertEqual(cls.sortaddr(),
-                         "\x00\x11\x00\x22\x00\x33\x00\x00"
-                         "\x00\x00\x00\x00\x00\x44\x00\x55")
+                         polybytes("\x00\x11\x00\x22\x00\x33\x00\x00"
+                                   "\x00\x00\x00\x00\x00\x44\x00\x55"))
         # Test sortaddr, ipv6, local
         cls.addr = "[11:22:33::44:55%8]:42"
         self.assertEqual(cls.sortaddr(),
-                         "\x00\x11\x00\x22\x00\x33\x00\x00"
-                         "\x00\x00\x00\x00\x00\x44\x00\x55")
+                         polybytes("\x00\x11\x00\x22\x00\x33\x00\x00"
+                                   "\x00\x00\x00\x00\x00\x44\x00\x55"))
         # Test sortaddr, ipv4
         cls.addr = "11.22.33.44:23"
-        self.assertEqual(cls.sortaddr(), (("\0" * 16) + "\x0b\x16\x21\x2c"))
+        self.assertEqual(cls.sortaddr(),
+                         polybytes((("\0" * 16) + "\x0b\x16\x21\x2c")))
         # Test __repr__
-        self.assertEqual(cls.__repr__(),
-                         "<MRUEntry: "
-                         "'last': '0x00000200.00000000', "
-                         "'addr': '11.22.33.44:23', 'rs': None, 'mv': None, "
-                         "'first': '0x00000100.00000000', 'ct': 4>")
+        # Python dicts enumeration order changes with different versions
+        if sys.version_info[0] < 3:
+            # Python 2
+            self.assertEqual(cls.__repr__(),
+                             "<MRUEntry: "
+                             "'last': '0x00000200.00000000', "
+                             "'addr': '11.22.33.44:23', 'rs': None, "
+                             "'mv': None, 'first': '0x00000100.00000000', "
+                             "'ct': 4>")
+        elif sys.version_info[1] >= 6:  # Already know it is 3.something
+            # Python 3.6+, dicts enumerate in assignment order
+            self.assertEqual(cls.__repr__(),
+                             "<MRUEntry: 'addr': '11.22.33.44:23', "
+                             "'last': '0x00000200.00000000', "
+                             "'first': '0x00000100.00000000', 'ct': 4, "
+                             "'mv': None, 'rs': None>")
+            pass
+        else:
+            # Python 3.x < 3.6, dicts enumerate randomly
+            # I can not test randomness of this type
+            pass
 
     def test_MRUList(self):
         # Test init
@@ -521,7 +616,7 @@ class TestControlPacket(unittest.TestCase):
         self.assertEqual(cls.status, 0)
         self.assertEqual(cls.associd, 0)
         self.assertEqual(cls.offset, 0)
-        self.assertEqual(cls.extension, "")
+        self.assertEqual(cls.extension, polybytes(""))
         self.assertEqual(cls.count, 0)
 
     def test_is_response(self):
@@ -607,7 +702,7 @@ class TestControlPacket(unittest.TestCase):
         self.assertEqual(cls.offset, 32)
         self.assertEqual(cls.count, 16)
         # Test flatten
-        self.assertEqual(cls.flatten(), totaldata)
+        self.assertEqual(cls.flatten(), polybytes(totaldata))
         # Test send
         send_data = []
 
@@ -615,7 +710,7 @@ class TestControlPacket(unittest.TestCase):
             send_data.append(pkt)
         cls.session.sendpkt = send_jig
         cls.send()
-        self.assertEqual(send_data, [totaldata])
+        self.assertEqual(send_data, [polybytes(totaldata)])
 
 
 class TestControlSession(unittest.TestCase):
@@ -888,10 +983,10 @@ class TestControlSession(unittest.TestCase):
         cls.sock = sockjig
         cls.debug = 3
         # Test
-        res = cls.sendpkt("blahfoo")
+        res = cls.sendpkt(polybytes("blahfoo"))
         self.assertEqual(res, 0)
         self.assertEqual(logjig.data, ["Sending 8 octets.  seq=0\n"])
-        self.assertEqual(sockjig.data, ["blahfoo\x00"])
+        self.assertEqual(sockjig.data, [polybytes("blahfoo\x00")])
         # Test error
         logjig.__init__()
         sockjig.fail_send = 1
@@ -922,15 +1017,15 @@ class TestControlSession(unittest.TestCase):
                               "***Internal error! Data too large "
                               "(" + str(len(data)) + ")\n"])
             # Test no auth
-            result = cls.sendrequest(1, 2, "foo")
+            result = cls.sendrequest(1, 2, polybytes("foo"))
             self.assertEqual(result.sequence, 1)
-            self.assertEqual(result.extension, "foo\x00")
+            self.assertEqual(result.extension, polybytes("foo\x00"))
             # Test with auth
             cls.keyid = 1
             cls.passwd = "qwerty"
-            result = cls.sendrequest(1, 2, "foo", True)
+            result = cls.sendrequest(1, 2, polybytes("foo"), True)
             self.assertEqual(result.sequence, 2)
-            self.assertEqual(result.extension, "foo\x00mac")
+            self.assertEqual(result.extension, polybytes("foo\x00mac"))
             # Test with auth keyid / password failure
             cls.keyid = None
             try:
@@ -958,7 +1053,7 @@ class TestControlSession(unittest.TestCase):
             sockjig.return_data = [
                 "\x0E\x81\x00\x00\x00\x03\x00\x02\x00\x00\x00\x00"]
             cls.getresponse(1, 2, True)
-            self.assertEqual(cls.response, "")
+            self.assertEqual(cls.response, polybytes(""))
             # Test with data
             sockjig.return_data = [
                 "\x0E\xA1\x00\x01\x00\x02\x00\x03\x00\x00\x00\x09"
@@ -969,7 +1064,8 @@ class TestControlSession(unittest.TestCase):
                 "quux=1\x00\x00"]
             cls.sequence = 1
             cls.getresponse(1, 3, True)
-            self.assertEqual(cls.response, "foo=4223,blah=248,x=23,quux=1")
+            self.assertEqual(cls.response,
+                             polybytes("foo=4223,blah=248,x=23,quux=1"))
             # Test MAXFRAGS bail
             maxtemp = ntpp.MAXFRAGS
             ntpp.MAXFRAGS = 1
@@ -1220,7 +1316,7 @@ class TestControlSession(unittest.TestCase):
                                     42, "", False)])
         # Test normal
         queries = []
-        cls.response = "\xDE\xAD\xF0\x0D"
+        cls.response = polybytes("\xDE\xAD\xF0\x0D")
         idlist = cls.readstat()
         self.assertEqual(len(idlist), 1)
         self.assertEqual(isinstance(idlist[0], ntpp.Peer), True)
@@ -1345,24 +1441,26 @@ class TestControlSession(unittest.TestCase):
         # Init
         cls = self.target()
         cls.doquery = doquery_jig
-        cls.response = "Config Succeeded    \n \x00 blah blah"
+        cls.response = polybytes("Config Succeeded    \n \x00 blah blah")
         # Test success
-        result = cls.config("Boo!")
+        result = cls.config(polybytes("Boo!"))
         self.assertEqual(result, True)
         self.assertEqual(queries,
-                         [(ntp.control.CTL_OP_CONFIGURE, 0, "Boo!", True)])
+                         [(ntp.control.CTL_OP_CONFIGURE, 0,
+                           polybytes("Boo!"), True)])
         # Test failure
         queries = []
-        cls.response = "whatever man..."
-        result = cls.config("Boo!")
+        cls.response = polybytes("whatever man...")
+        result = cls.config(polybytes("Boo!"))
         self.assertEqual(result, False)
         self.assertEqual(queries,
-                         [(ntp.control.CTL_OP_CONFIGURE, 0, "Boo!", True)])
+                         [(ntp.control.CTL_OP_CONFIGURE, 0,
+                           polybytes("Boo!"), True)])
         # Test no response
         queries = []
         cls.response = ""
         try:
-            cls.config("blah")
+            cls.config(polybytes("blah"))
             errored = False
         except ctlerr as e:
             errored = e.message
@@ -1378,7 +1476,7 @@ class TestControlSession(unittest.TestCase):
         cls = self.target()
         cls.doquery = doquery_jig
         # Test success
-        cls.response = "nonce=blah blah  "
+        cls.response = polybytes("nonce=blah blah  ")
         result = cls.fetch_nonce()
         self.assertEqual(result, "nonce=blah blah")
         self.assertEqual(queries,
@@ -1386,7 +1484,7 @@ class TestControlSession(unittest.TestCase):
         # Test failure
         queries = []
         cls.logfp = filefp
-        cls.response = "blah blah"
+        cls.response = polybytes("blah blah")
         try:
             result = cls.fetch_nonce()
             errored = False
@@ -1789,10 +1887,10 @@ class TestAuthenticator(unittest.TestCase):
             fakehashlibmod = jigs.HashlibModuleJig()
             ntpp.hashlib = fakehashlibmod
             # Test no digest
-            self.assertEqual(f(None, None, None, None), None)
+            self.assertEqual(f("", 0, None, polybytes("")), None)
             # Test with digest
             self.assertEqual(f("foo", 0x42, "bar", "quux"),
-                             "\x00\x00\x00\x42blahblahblahblah")
+                             polybytes("\x00\x00\x00\x42blahblahblahblah"))
         finally:
             ntpp.hashlib = temphash
 
@@ -1815,9 +1913,9 @@ class TestAuthenticator(unittest.TestCase):
             fakehashlibmod = jigs.HashlibModuleJig()
             ntpp.hashlib = fakehashlibmod
             # Test good
-            self.assertEqual(cls.verify_mac(good_pkt), True)
+            self.assertEqual(cls.verify_mac(polybytes(good_pkt)), True)
             # Test bad
-            self.assertEqual(cls.verify_mac(bad_pkt), False)
+            self.assertEqual(cls.verify_mac(polybytes(bad_pkt)), False)
         finally:
             ntpp.hashlib = temphash
 



View it on GitLab: https://gitlab.com/NTPsec/ntpsec/commit/9aa67300d5b00300f2062d7be6953942020d4cb3

---
View it on GitLab: https://gitlab.com/NTPsec/ntpsec/commit/9aa67300d5b00300f2062d7be6953942020d4cb3
You're receiving this email because of your account on gitlab.com.
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <https://lists.ntpsec.org/pipermail/vc/attachments/20170921/8394c9ce/attachment.html>


More information about the vc mailing list