(original) (raw)

diff -r 78a2d1169be1 Doc/library/socket.rst --- a/Doc/library/socket.rst Tue Apr 14 11:21:26 2015 -0700 +++ b/Doc/library/socket.rst Tue Apr 14 15:34:33 2015 -0400 @@ -782,6 +782,16 @@ .. versionadded:: 3.3 +.. function:: get_socket_type(sock) + + Return a SocketKind object with the type of the sock + paramater. + This function removes the :const:`SOCK_NONBLOCK` and + :const:`SOCK_CLOEXEC` flags if present to return a valid + SocketKind object. + + .. versionadded:: 3.5 + .. _socket-objects: diff -r 78a2d1169be1 Lib/socket.py --- a/Lib/socket.py Tue Apr 14 11:21:26 2015 -0700 +++ b/Lib/socket.py Tue Apr 14 15:34:33 2015 -0400 @@ -731,3 +731,26 @@ _intenum_converter(socktype, SocketKind), proto, canonname, sa)) return addrlist + +def get_socket_type(sock): + """Retrieves the socket type of sock. + + Return a SocketKind object with the socket type, i.e: + SOCK_STREAM + SOCK_DGRAM + SOCK_RAW + Removing the SOCK_NONBLOCK and/or SOCK_CLOEXEC flags. + + """ + sock_type = sock.type + if not isinstance(sock_type, SocketKind): + try: + flags_to_remove = [SocketKind.SOCK_CLOEXEC, SocketKind.SOCK_NONBLOCK] + except AttributeError: + # SOCK_NONBLOCK and SOCK_CLOEXEC don't exist on this OS + return sock_type + for flag in flags_to_remove: + if sock_type & int(flag): + sock_type = sock_type ^ int(flag) + sock_type = _intenum_converter(sock_type, SocketKind) + return sock_type diff -r 78a2d1169be1 Lib/test/test_socket.py --- a/Lib/test/test_socket.py Tue Apr 14 11:21:26 2015 -0700 +++ b/Lib/test/test_socket.py Tue Apr 14 15:34:33 2015 -0400 @@ -1434,6 +1434,52 @@ self.assertEqual(s.family, 42424) self.assertEqual(s.type, 13331) + def test_get_socket_type(self): + with socket.socket(type=socket.SOCK_STREAM) as s: + self.assertEqual(socket.get_socket_type(s), + socket.SocketKind(socket.SOCK_STREAM)) + + def test_get_socket_type_file_type(self): + fd, _ = tempfile.mkstemp() + with socket.socket(type=13331, fileno=fd) as s: + self.assertEqual(socket.get_socket_type(s), 13331) + + def test_get_socket_type_set_timeout(self): + with socket.socket(type=socket.SOCK_STREAM) as s: + s.settimeout(2) + self.assertEqual(socket.get_socket_type(s), + socket.SocketKind(socket.SOCK_STREAM)) + + @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'), + 'test needs socket.SOCK_NONBLOCK') + @support.requires_linux_version(2, 6, 28) + def test_get_socket_type_non_block_flag(self): + type_flags = socket.SOCK_STREAM | socket.SOCK_NONBLOCK + with socket.socket(type=type_flags) as s: + self.assertEqual(socket.get_socket_type(s), + socket.SocketKind(socket.SOCK_STREAM)) + + @unittest.skipUnless(hasattr(socket, 'SOCK_CLOEXEC'), + 'test needs socket.SOCK_CLOEXEC') + @support.requires_linux_version(2, 6, 28) + def test_get_socket_type_cloexec_flag(self): + type_flags = socket.SOCK_DGRAM | socket.SOCK_CLOEXEC + with socket.socket(type=type_flags) as s: + self.assertEqual(socket.get_socket_type(s), + socket.SocketKind(socket.SOCK_DGRAM)) + + @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'), + 'test needs socket.SOCK_NONBLOCK') + @unittest.skipUnless(hasattr(socket, 'SOCK_CLOEXEC'), + 'test needs socket.SOCK_CLOEXEC') + @support.requires_linux_version(2, 6, 28) + def test_get_socket_type_cloexec_non_block_flag(self): + type_flags = socket.SOCK_DGRAM | socket.SOCK_CLOEXEC | socket.SOCK_NONBLOCK + with socket.socket(type=type_flags) as s: + self.assertEqual(socket.get_socket_type(s), + socket.SocketKind(socket.SOCK_DGRAM)) + + @unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.') class BasicCANTest(unittest.TestCase):