#!/usr/bin/python3
#
# Copyright 2015 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
from errno import *  # pylint: disable=wildcard-import
import binascii
import os
import random
import select
from socket import *  # pylint: disable=wildcard-import
import struct
import threading
import time
import unittest

import cstruct
import multinetwork_base
import net_test
import packets
import sock_diag
import tcp_test

# Mostly empty structure definition containing only the fields we currently use.
# NOTE(review): the format presumably follows struct-module syntax, i.e. "64x"
# skips 64 bytes and "I" is one unsigned int named tcpi_rcv_ssthresh — confirm
# against cstruct.
TcpInfo = cstruct.Struct("TcpInfo", "64xI", "tcpi_rcv_ssthresh")

# Number of socket pairs created by SockDiagBaseTest._CreateLotsOfSockets.
NUM_SOCKETS = 30
# Empty bytecode, passed to the dump calls below when no filter is wanted.
NO_BYTECODE = b""

# IP protocol number used for SCTP sockets in this file.
IPPROTO_SCTP = 132

def HaveSctp():
  """Reports whether an IPv4 SCTP stream socket can be created and closed.

  Returns:
    True if creating and closing the socket raised no IOError, else False.
  """
  supported = True
  try:
    socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP).close()
  except IOError:
    supported = False
  return supported

# Probed once, when the module is imported.
HAVE_SCTP = HaveSctp()


class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
  """Basic tests for SOCK_DIAG functionality.

  Holds the fixture shared by the SOCK_DIAG test classes: self.sock_diag (a
  sock_diag.SockDiag instance) and self.socketpairs (socket pairs that tearDown
  closes), plus assertion helpers.

  Relevant kernel commits:
    android-3.4:
      ab4a727 net: inet_diag: zero out uninitialized idiag_{src,dst} fields
      99ee451 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-3.10:
      3eb409b net: inet_diag: zero out uninitialized idiag_{src,dst} fields
      f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-3.18:
      e603010 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-4.4:
      525ee59 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
  """

  @staticmethod
  def _CreateLotsOfSockets(socktype):
    """Creates NUM_SOCKETS socket pairs on randomly chosen loopback addresses.

    Args:
      socktype: The socket type passed to net_test.CreateSocketPair, e.g.,
        SOCK_STREAM or SOCK_DGRAM.

    Returns:
      A dict mapping (addr, sport, dport) tuples to socketpairs, where sport
      and dport are the local ports of the first and second socket of the pair.
    """
    # Dict mapping (addr, sport, dport) tuples to socketpairs.
    socketpairs = {}
    for _ in range(NUM_SOCKETS):
      family, addr = random.choice([
          (AF_INET, "127.0.0.1"),
          (AF_INET6, "::1"),
          (AF_INET6, "::ffff:127.0.0.1")])
      socketpair = net_test.CreateSocketPair(family, socktype, addr)
      sport, dport = (socketpair[0].getsockname()[1],
                      socketpair[1].getsockname()[1])
      socketpairs[(addr, sport, dport)] = socketpair
    return socketpairs

  def assertSocketClosed(self, sock):
    """Asserts that getpeername() on sock fails with ENOTCONN."""
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    """Checks that sock is connected; getpeername() raises if it is not."""
    sock.getpeername()  # No errors? Socket is alive and connected.

  def assertSocketsClosed(self, socketpair):
    """Asserts that every socket in socketpair is closed."""
    for sock in socketpair:
      self.assertSocketClosed(sock)

  def assertMarkIs(self, mark, attrs):
    """Asserts that the INET_DIAG_MARK entry of attrs equals mark.

    A missing entry is treated as None.
    """
    self.assertEqual(mark, attrs.get("INET_DIAG_MARK", None))

  def assertSockInfoMatchesSocket(self, s, info):
    """Asserts that a (diag_msg, attrs) tuple describes the socket s.

    Compares the address family, the source address and port, the destination
    address and port, and the socket mark. If the diag message has a wildcard
    destination, asserts instead that s is not connected.

    Args:
      s: The socket to compare against.
      info: A (diag_msg, attrs) tuple, e.g., as returned by GetSockInfo.
    """
    diag_msg, attrs = info
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    self.assertEqual(diag_msg.family, family)

    src, sport = s.getsockname()[0:2]
    self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
    self.assertEqual(diag_msg.id.sport, sport)

    if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
      dst, dport = s.getpeername()[0:2]
      self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
      self.assertEqual(diag_msg.id.dport, dport)
    else:
      self.assertRaisesErrno(ENOTCONN, s.getpeername)

    mark = s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
    self.assertMarkIs(mark, attrs)

  def PackAndCheckBytecode(self, instructions):
    """Packs instructions into bytecode and checks that it decodes cleanly.

    Args:
      instructions: A list of instruction tuples accepted by
        sock_diag.PackBytecode.

    Returns:
      The packed bytecode.
    """
    bytecode = self.sock_diag.PackBytecode(instructions)
    decoded = self.sock_diag.DecodeBytecode(bytecode)
    self.assertEqual(len(instructions), len(decoded))
    self.assertNotIn("???", decoded)
    return bytecode

  def _EventDuringBlockingCall(self, sock, call, expected_errno, event):
    """Simulates an external event during a blocking call on sock.

    Args:
      sock: The socket to use.
      call: A function, the call to make. Takes one parameter, sock.
      expected_errno: The value that call is expected to fail with, or None if
        call is expected to succeed.
      event: A function, the event that will happen during the blocking call.
        Takes one parameter, sock.
    """
    thread = SocketExceptionThread(sock, call)
    thread.start()
    # Give the thread time to enter the blocking call before the event fires.
    time.sleep(0.1)
    event(sock)
    # The event must have unblocked the call; don't wait forever if it didn't.
    thread.join(1)
    self.assertFalse(thread.is_alive())
    if expected_errno is not None:
      self.assertIsNotNone(thread.exception)
      self.assertIsInstance(thread.exception, IOError,
                            "Expected IOError, got %s" % thread.exception)
      self.assertEqual(expected_errno, thread.exception.errno)
    else:
      self.assertIsNone(thread.exception)
    self.assertSocketClosed(sock)

  def CloseDuringBlockingCall(self, sock, call, expected_errno):
    """Destroys sock via SOCK_DIAG while call(sock) is blocked on it.

    Args:
      sock: The socket to use.
      call: A function, the blocking call to make. Takes one parameter, sock.
      expected_errno: The errno call is expected to fail with, or None if it is
        expected to succeed.
    """
    self._EventDuringBlockingCall(
        sock, call, expected_errno, self.sock_diag.CloseSocketFromFd)

  def setUp(self):
    """Creates the SockDiag instance and an empty self.socketpairs dict."""
    super(SockDiagBaseTest, self).setUp()
    self.sock_diag = sock_diag.SockDiag()
    self.socketpairs = {}

  def tearDown(self):
    """Closes every socket in self.socketpairs, then runs the base tearDown."""
    try:
      for socketpair in list(self.socketpairs.values()):
        for s in socketpair:
          s.close()
    finally:
      # Always run the base class cleanup, even if closing a socket failed.
      super(SockDiagBaseTest, self).tearDown()


class SockDiagTest(SockDiagBaseTest):
  """Tests dumping, looking up and filtering sockets via SOCK_DIAG."""

  def testFindsMappedSockets(self):
    """Tests that inet_diag_find_one_icsk can find mapped sockets."""
    socketpair = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                           "::ffff:127.0.0.1")
    try:
      for sock in socketpair:
        diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
        diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
        self.sock_diag.GetSockInfo(diag_req)
        # No errors? Good.
    finally:
      # Close the sockets even on failure so they cannot leak into tests that
      # expect no other open sockets.
      for sock in socketpair:
        sock.close()

  def CheckFindsAllMySockets(self, socktype, proto):
    """Tests that basic socket dumping works.

    Args:
      socktype: The socket type to create, e.g., SOCK_STREAM.
      proto: The IP protocol to dump, e.g., IPPROTO_TCP.
    """
    self.socketpairs = self._CreateLotsOfSockets(socktype)
    sockets = self.sock_diag.DumpAllInetSockets(proto, NO_BYTECODE)
    self.assertGreaterEqual(len(sockets), NUM_SOCKETS)

    # Find the cookies for all of our sockets.
    cookies = {}
    for diag_msg, unused_attrs in sockets:
      addr = self.sock_diag.GetSourceAddress(diag_msg)
      sport = diag_msg.id.sport
      dport = diag_msg.id.dport
      # A dumped socket is ours if its ports match a pair in either direction.
      if ((addr, sport, dport) in self.socketpairs or
          (addr, dport, sport) in self.socketpairs):
        cookies[(addr, sport, dport)] = diag_msg.id.cookie

    # Did we find all the cookies?
    self.assertEqual(2 * NUM_SOCKETS, len(cookies))

    socketpairs = list(self.socketpairs.values())
    random.shuffle(socketpairs)
    for socketpair in socketpairs:
      for sock in socketpair:
        # Check that we can find a diag_msg by scanning a dump.
        self.assertSockInfoMatchesSocket(
            sock,
            self.sock_diag.FindSockInfoFromFd(sock))
        cookie = self.sock_diag.FindSockDiagFromFd(sock).id.cookie

        # Check that we can find a diag_msg once we know the cookie.
        req = self.sock_diag.DiagReqFromSocket(sock)
        req.id.cookie = cookie
        if proto == IPPROTO_UDP:
          # Kernel bug: for UDP sockets, the order of arguments must be swapped.
          # See testDemonstrateUdpGetSockIdBug.
          req.id.sport, req.id.dport = req.id.dport, req.id.sport
          req.id.src, req.id.dst = req.id.dst, req.id.src
        info = self.sock_diag.GetSockInfo(req)
        self.assertSockInfoMatchesSocket(sock, info)

    for socketpair in socketpairs:
      for sock in socketpair:
        sock.close()

  def assertItemsEqual(self, expected, actual):
    """Asserts that expected and actual hold the same items, in any order.

    Uses the base class's assertItemsEqual where it exists and assertCountEqual
    otherwise; this was renamed in python3 but has the same behaviour.
    """
    base = super(SockDiagTest, self)
    # Pick the method by name up front, so that an AttributeError raised by
    # the comparison itself is not mistaken for a missing method.
    compare = getattr(base, "assertItemsEqual", None)
    if compare is None:
      compare = base.assertCountEqual
    compare(expected, actual)

  def testFindsAllMySocketsTcp(self):
    """Runs CheckFindsAllMySockets for TCP."""
    self.CheckFindsAllMySockets(SOCK_STREAM, IPPROTO_TCP)

  def testFindsAllMySocketsUdp(self):
    """Runs CheckFindsAllMySockets for UDP."""
    self.CheckFindsAllMySockets(SOCK_DGRAM, IPPROTO_UDP)

  def testBytecodeCompilation(self):
    """Checks bytecode packing against known bytes, then filters dumps with it."""
    # pylint: disable=bad-whitespace
    instructions = [
        (sock_diag.INET_DIAG_BC_S_GE,   1, 8, 0),                      # 0
        (sock_diag.INET_DIAG_BC_D_LE,   1, 7, 0xffff),                 # 8
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)),       # 16
        (sock_diag.INET_DIAG_BC_JMP,    1, 3, None),                   # 44
        (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)),  # 48
        (sock_diag.INET_DIAG_BC_D_LE,   1, 3, 0x6665),  # not used     # 64
        (sock_diag.INET_DIAG_BC_NOP,    1, 1, None),                   # 72
                                                                       # 76 acc
                                                                       # 80 rej
    ]
    # pylint: enable=bad-whitespace
    bytecode = self.PackAndCheckBytecode(instructions)
    expected = (
        b"0208500000000000"
        b"050848000000ffff"
        b"071c20000a800000ffffffff00000000000000000000000000000001"
        b"01041c00"
        b"0718200002200000ffffffff7f000001"
        b"0508100000006566"
        b"00040400"
    )
    states = 1 << tcp_test.TCP_ESTABLISHED
    self.assertEqual(expected, binascii.hexlify(bytecode))
    self.assertEqual(76, len(bytecode))
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode,
                                                        states=states)
    allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE,
                                                   states=states)
    self.assertItemsEqual(allsockets, filteredsockets)

    # Pick a few sockets in hash table order, and check that the bytecode we
    # compiled selects them properly.
    for socketpair in list(self.socketpairs.values())[:20]:
      for s in socketpair:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)
        instructions = [
            (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport),
            (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport),
        ]
        bytecode = self.PackAndCheckBytecode(instructions)
        self.assertEqual(32, len(bytecode))
        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
        self.assertEqual(1, len(sockets))

        # TODO: why doesn't comparing the cstructs work?
        self.assertEqual(diag_msg.Pack(), sockets[0][0].Pack())

  def testCrossFamilyBytecode(self):
    """Checks for a cross-family bug in inet_diag_hostcond matching.

    Relevant kernel commits:
      android-3.4:
        f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run()
    """
    # TODO: this is only here because the test fails if there are any open
    # sockets other than the ones it creates itself. Make the bytecode more
    # specific and remove it.
    states = 1 << tcp_test.TCP_ESTABLISHED
    self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE,
                                                       states=states))

    # Every pair created below is recorded here so that it is closed even if
    # an assertion fails part-way through.
    pairs = []
    try:
      pairs.append(
          net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1"))
      pairs.append(net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1"))

      bytecode4 = self.PackAndCheckBytecode([
          (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))])
      bytecode6 = self.PackAndCheckBytecode([
          (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))])

      # IPv4/v6 filters must never match IPv6/IPv4 sockets...
      v4socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4,
                                                  states=states)
      self.assertTrue(v4socks)
      self.assertTrue(all(d.family == AF_INET for d, _ in v4socks))

      v6socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6,
                                                  states=states)
      self.assertTrue(v6socks)
      self.assertTrue(all(d.family == AF_INET6 for d, _ in v6socks))

      # Except for mapped addresses, which match both IPv4 and IPv6.
      pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                        "::ffff:127.0.0.1")
      pairs.append(pair5)
      diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5]
      v4socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(
          IPPROTO_TCP, bytecode4, states=states)]
      v6socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(
          IPPROTO_TCP, bytecode6, states=states)]
      self.assertTrue(all(d in v4socks for d in diag_msgs))
      self.assertTrue(all(d in v6socks for d in diag_msgs))
    finally:
      for pair in pairs:
        for sock in pair:
          sock.close()

  def testPortComparisonValidation(self):
    """Checks for a bug in validating port comparison bytecode.

    Relevant kernel commits:
      android-3.4:
        5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads
    """
    bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8))
    self.assertEqual("???",
                      self.sock_diag.DecodeBytecode(bytecode))
    self.assertRaisesErrno(
        EINVAL,
        self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack())

  def testNonSockDiagCommand(self):
    """Checks that a dump with an unexpected command code fails with EINVAL."""
    def DiagDump(code):
      sock_id = self.sock_diag._EmptyInetDiagSockId()
      req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
                                     sock_id))
      self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg)

    op = sock_diag.SOCK_DIAG_BY_FAMILY
    DiagDump(op)  # No errors? Good.
    self.assertRaisesErrno(EINVAL, DiagDump, op + 17)

  def CheckSocketCookie(self, inet, addr):
    """Tests that getsockopt SO_COOKIE can get cookie for all sockets.

    Args:
      inet: The address family of the socket pair to create.
      addr: The address passed to net_test.CreateSocketPair.
    """
    socketpair = net_test.CreateSocketPair(inet, SOCK_STREAM, addr)
    try:
      for sock in socketpair:
        diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
        cookie = sock.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8)
        self.assertEqual(diag_msg.id.cookie, cookie)
    finally:
      # Don't leak the sockets if a cookie does not match.
      for sock in socketpair:
        sock.close()

  def testGetsockoptcookie(self):
    """Runs CheckSocketCookie for IPv4 and IPv6 loopback."""
    self.CheckSocketCookie(AF_INET, "127.0.0.1")
    self.CheckSocketCookie(AF_INET6, "::1")

  def testDemonstrateUdpGetSockIdBug(self):
    """Documents a bug: getting UDP sockets requires swapping src and dst."""
    # TODO: this is because udp_dump_one mistakenly uses __udp[46]_lib_lookup
    # by passing the source address as the source address argument.
    # Unfortunately those functions are intended to match local sockets based
    # on received packets, and the argument that ends up being compared with
    # e.g., sk_daddr is actually saddr, not daddr. udp_diag_destroy does not
    # have this bug.  Upstream has confirmed that this will not be fixed:
    # https://www.mail-archive.com/netdev@vger.kernel.org/msg248638.html
    for version in [4, 5, 6]:
      family = net_test.GetAddressFamily(version)
      s = socket(family, SOCK_DGRAM, 0)
      try:
        self.SelectInterface(s, self.RandomNetid(), "mark")
        s.connect((self.GetRemoteSocketAddress(version), 53))

        # Create a fully-specified diag req from our socket, including cookie
        # if we can get it.
        req = self.sock_diag.DiagReqFromSocket(s)
        req.id.cookie = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE,
                                     8)

        # As is, this request does not find anything.
        with self.assertRaisesErrno(ENOENT):
          self.sock_diag.GetSockInfo(req)

        # But if we swap src and dst, the kernel finds our socket.
        req.id.sport, req.id.dport = req.id.dport, req.id.sport
        req.id.src, req.id.dst = req.id.dst, req.id.src

        self.assertSockInfoMatchesSocket(s, self.sock_diag.GetSockInfo(req))
      finally:
        # Close the socket even if one of the checks above failed.
        s.close()


class SockDestroyTest(SockDiagBaseTest):
  """Checks that SOCK_DESTROY closes sockets as expected.

  Relevant kernel commits:
    net-next:
      b613f56 net: diag: split inet_diag_dump_one_icsk into two
      64be0ae net: diag: Add the ability to destroy a socket.
      6eb5d2e net: diag: Support SOCK_DESTROY for inet sockets.
      c1e64e2 net: diag: Support destroying TCP sockets.
      2010b93 net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.4:
      d48ec88 net: diag: split inet_diag_dump_one_icsk into two
      2438189 net: diag: Add the ability to destroy a socket.
      7a2ddbc net: diag: Support SOCK_DESTROY for inet sockets.
      44047b2 net: diag: Support destroying TCP sockets.
      200dae7 net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.10:
      9eaff90 net: diag: split inet_diag_dump_one_icsk into two
      d60326c net: diag: Add the ability to destroy a socket.
      3d4ce85 net: diag: Support SOCK_DESTROY for inet sockets.
      529dfc6 net: diag: Support destroying TCP sockets.
      9c712fe net: tcp: deal with listen sockets properly in tcp_abort.

    android-3.18:
      100263d net: diag: split inet_diag_dump_one_icsk into two
      194c5f3 net: diag: Add the ability to destroy a socket.
      8387ea2 net: diag: Support SOCK_DESTROY for inet sockets.
      b80585a net: diag: Support destroying TCP sockets.
      476c6ce net: tcp: deal with listen sockets properly in tcp_abort.

    android-4.1:
      56eebf8 net: diag: split inet_diag_dump_one_icsk into two
      fb486c9 net: diag: Add the ability to destroy a socket.
      0c02b7e net: diag: Support SOCK_DESTROY for inet sockets.
      67c71d8 net: diag: Support destroying TCP sockets.
      a76e0ec net: tcp: deal with listen sockets properly in tcp_abort.
      e6e277b net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

    android-4.4:
      76c83a9 net: diag: split inet_diag_dump_one_icsk into two
      f7cf791 net: diag: Add the ability to destroy a socket.
      1c42248 net: diag: Support SOCK_DESTROY for inet sockets.
      c9e8440d net: diag: Support destroying TCP sockets.
      3d9502c tcp: diag: add support for request sockets to tcp_abort()
      001cf75 net: tcp: deal with listen sockets properly in tcp_abort.
  """

  def testClosesSockets(self):
    """Destroys one socket of each pair and checks that both ends close."""
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    for pair in self.socketpairs.values():
      # Destroying either end sends a RST, which closes the peer as well.
      victim = random.choice(pair)
      if random.randrange(0, 2) != 1:
        found = self.sock_diag.FindSockDiagFromFd(victim)

        # A request carrying the wrong cookie must fail and leave the socket
        # connected.
        good_cookie = found.id.cookie
        found.id.cookie = os.urandom(len(good_cookie))
        destroy_req = self.sock_diag.DiagReqFromDiagMsg(found, IPPROTO_TCP)
        self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, destroy_req)
        self.assertSocketConnected(victim)

        # With the real cookie restored, the same request succeeds.
        destroy_req.id.cookie = good_cookie
        self.sock_diag.CloseSocket(destroy_req)
      else:
        self.sock_diag.CloseSocketFromFd(victim)

      # Both ends of the pair must be closed by now.
      self.assertSocketsClosed(pair)

  # TODO:
  # Test that killing unix sockets returns EOPNOTSUPP.


class SocketExceptionThread(threading.Thread):
  """Daemon thread that runs operation(sock) and records how it failed.

  If the operation raises IOError or AssertionError, the exception object is
  stored in self.exception; otherwise self.exception stays None.
  """

  def __init__(self, sock, operation):
    """Stores the socket and the one-argument callable to run on it."""
    super(SocketExceptionThread, self).__init__()
    self.sock = sock
    self.operation = operation
    self.exception = None
    self.daemon = True

  def run(self):
    """Calls the operation, capturing an IOError or AssertionError if raised."""
    try:
      self.operation(self.sock)
    except (AssertionError, IOError) as err:
      self.exception = err


class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """SOCK_DIAG tests that need the TCP connection fixtures of TcpBaseTest."""

  def testIpv4MappedSynRecvSocket(self):
    """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.

    Relevant kernel commits:
         android-3.4:
           457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
    """
    syn_recv = tcp_test.TCP_SYN_RECV
    chosen_netid = random.choice(list(self.tuns.keys()))
    self.IncomingConnection(5, syn_recv, chosen_netid)

    # Dump AF_INET6 TCP sockets in SYN_RECV state on our port.
    wanted_id = self.sock_diag._EmptyInetDiagSockId()
    wanted_id.sport = self.port
    request = sock_diag.InetDiagReqV2(
        (AF_INET6, IPPROTO_TCP, 0, 1 << syn_recv, wanted_id))
    found = self.sock_diag.Dump(request, NO_BYTECODE)

    self.assertTrue(found)
    for diag_msg, _ in found:
      self.assertEqual(syn_recv, diag_msg.state)
      self.assertEqual(self.sock_diag.PaddedAddress(self.remotesockaddr),
                       diag_msg.id.dst)
      self.assertEqual(self.sock_diag.PaddedAddress(self.mysockaddr),
                       diag_msg.id.src)


class TcpRcvWindowTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """Checks that TCP receive windows exceed RWND_SIZE."""

  RWND_SIZE = 64000
  TCP_DEFAULT_INIT_RWND = "/proc/sys/net/ipv4/tcp_default_init_rwnd"

  def setUp(self):
    """Runs base setUp and asserts that the init_rwnd sysctl does not exist."""
    super(TcpRcvWindowTest, self).setUp()
    sysctl_path = self.TCP_DEFAULT_INIT_RWND
    self.assertRaisesErrno(ENOENT, open, sysctl_path, "w")

  def checkInitRwndSize(self, version, netid):
    """Asserts that an accepted socket's tcpi_rcv_ssthresh exceeds RWND_SIZE.

    Args:
      version: The IP version passed to IncomingConnection.
      netid: The netid the connection arrives on.
    """
    self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, netid)
    raw_info = self.accepted.getsockopt(net_test.SOL_TCP, net_test.TCP_INFO,
                                        len(TcpInfo))
    info = TcpInfo(raw_info)
    failure_msg = ("Tcp rwnd of netid=%d, version=%d is not enough. "
                   "Expect: %d, actual: %d" % (netid, version, self.RWND_SIZE,
                                               info.tcpi_rcv_ssthresh))
    self.assertLess(self.RWND_SIZE, info.tcpi_rcv_ssthresh, failure_msg)
    self.CloseSockets()

  def checkSynPacketWindowSize(self, version, netid):
    """Asserts that the window of an outgoing SYN exceeds RWND_SIZE.

    Args:
      version: The IP version of the connection to attempt.
      netid: The netid the SYN is expected on.
    """
    sock = self.BuildSocket(version, net_test.TCPSocket, netid, "mark")
    local_addr = self.MyAddress(version, netid)
    remote_addr = self.GetRemoteAddress(version)
    remote_sockaddr = self.GetRemoteSocketAddress(version)
    desc, expected_syn = packets.SYN(53, version, local_addr, remote_addr,
                                     sport=None, seq=None)
    # The connect is expected to fail with EINPROGRESS.
    self.assertRaisesErrno(EINPROGRESS, sock.connect, (remote_sockaddr, 53))
    msg = "IPv%s TCP connect: expected %s on %s" % (
        version, desc, self.GetInterfaceName(netid))
    observed_syn = self.ExpectPacketOn(netid, msg, expected_syn)
    self.assertLess(self.RWND_SIZE, observed_syn.window)
    sock.close()

  def testTcpCwndSize(self):
    """Runs both window checks for every IP version and every netid."""
    for ip_version in [4, 5, 6]:
      for net in self.NETIDS:
        self.checkInitRwndSize(ip_version, net)
        self.checkSynPacketWindowSize(ip_version, net)


class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """Tests SOCK_DESTROY on TCP sockets in various states."""

  def setUp(self):
    """Runs base setUp, then picks a random netid from self.tuns."""
    super(SockDestroyTcpTest, self).setUp()
    self.netid = random.choice(list(self.tuns.keys()))

  def ExpectRst(self, msg):
    """Expects the packet built by self.RstPacket() on self.netid.

    Args:
      msg: Text that prefixes the failure message.
    """
    desc, rst = self.RstPacket()
    msg = "%s: expecting %s: " % (msg, desc)
    self.ExpectPacketOn(self.netid, msg, rst)

  def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
    """Closes the socket and checks whether a RST is sent or not.

    Exactly one of sock and req must be given.

    Args:
      sock: The socket to destroy by file descriptor, or None.
      req: The diag request identifying the socket to destroy, or None.
      expect_reset: Whether a RST is expected on self.netid.
      msg: Text that prefixes the failure message.
      do_close: Whether to also close() sock afterwards. Only used when sock is
        given.
    """
    if sock is not None:
      self.assertIsNone(req, "Must specify sock or req, not both")
      self.sock_diag.CloseSocketFromFd(sock)
      self.assertRaisesErrno(EINVAL, sock.accept)
    else:
      # sock is None here, so the thing to check is that a req was supplied.
      self.assertIsNotNone(req, "Must specify either sock or req")
      self.sock_diag.CloseSocket(req)

    if expect_reset:
      self.ExpectRst(msg)
    else:
      msg = "%s: " % msg
      self.ExpectNoPacketsOn(self.netid, msg)

    if sock is not None and do_close:
      sock.close()

  def CheckTcpReset(self, state, statename):
    """Destroys sockets of an incoming connection in the given state.

    For each IP version, destroying the listening socket must send no RST and,
    unless state is TCP_LISTEN, destroying the accepted socket must send one.

    Args:
      state: The tcp_test state passed to IncomingConnection.
      statename: The name of that state, used in failure messages.
    """
    for version in [4, 5, 6]:
      msg = "Closing incoming IPv%d %s socket" % (version, statename)
      self.IncomingConnection(version, state, self.netid)
      self.CheckRstOnClose(self.s, None, False, msg)
      if state != tcp_test.TCP_LISTEN:
        msg = "Closing accepted IPv%d %s socket" % (version, statename)
        self.CheckRstOnClose(self.accepted, None, True, msg)
        self.CloseSockets()

  def testTcpResets(self):
    """Checks that closing sockets in appropriate states sends a RST."""
    self.CheckTcpReset(tcp_test.TCP_LISTEN, "TCP_LISTEN")
    self.CheckTcpReset(tcp_test.TCP_ESTABLISHED, "TCP_ESTABLISHED")
    self.CheckTcpReset(tcp_test.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")

  def testFinWait1Socket(self):
    """Checks what happens to a FIN_WAIT1 socket after a SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)

      # Get the cookie so we can find this socket after we close it.
      diag_msg = self.sock_diag.FindSockDiagFromFd(self.accepted)
      diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)

      # Close the socket and check that it goes into FIN_WAIT1 and sends a FIN.
      net_test.EnableFinWait(self.accepted)
      self.accepted.close()
      self.accepted = None
      diag_req.states = 1 << tcp_test.TCP_FIN_WAIT1
      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
      self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state)
      desc, fin = self.FinPacket()
      msg = "Closing FIN_WAIT1 socket"
      self.ExpectPacketOn(self.netid, msg, fin)

      # Destroy the socket.
      self.sock_diag.CloseSocketFromFd(self.s)
      self.assertRaisesErrno(EINVAL, self.s.accept)
      try:
        diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
      except IOError as e:
        # Newer kernels will have closed the socket and sent a RST.
        self.assertEqual(ENOENT, e.errno)
        self.ExpectRst(msg)
        self.CloseSockets()
        # NOTE(review): this returns from the whole test, so on such kernels
        # the remaining IP versions are not exercised. Confirm whether
        # "continue" was intended.
        return

      # Older kernels don't support closing FIN_WAIT1 sockets.
      # Check that no RST is sent and that the socket is still in FIN_WAIT1, and
      # advances to FIN_WAIT2 if the FIN is ACked.
      msg = "%s: " % msg
      self.ExpectNoPacketsOn(self.netid, msg)
      self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state)

      # ACK the FIN so we don't trip over retransmits in future tests.
      finversion = 4 if version == 5 else version
      desc, finack = packets.ACK(finversion, self.remoteaddr, self.myaddr, fin)
      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
      self.ReceivePacketOn(self.netid, finack)

      # See if we can find the resulting FIN_WAIT2 socket.
      diag_req.states = 1 << tcp_test.TCP_FIN_WAIT2
      infos = self.sock_diag.Dump(diag_req, NO_BYTECODE)
      self.assertTrue(any(diag_msg.state == tcp_test.TCP_FIN_WAIT2
                          for diag_msg, attrs in infos),
                      "Expected to find FIN_WAIT2 socket in %s" % infos)

      self.CloseSockets()

  def FindChildSockets(self, s):
    """Finds the SYN_RECV child sockets of a given listening socket.

    Also checks that a mark filter that does not match returns nothing.

    Args:
      s: The listening socket.

    Returns:
      A list of diag requests, one for each SYN_RECV or ESTABLISHED socket that
      matches the listening socket's diag message and carries mark self.netid.
    """
    d = self.sock_diag.FindSockDiagFromFd(s)
    req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
    req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED
    # The parent's cookie is replaced with zeros for the dump.
    req.id.cookie = b"\x00" * 8

    bad_bytecode = self.PackAndCheckBytecode(
        [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (0xffff, 0xffff))])
    self.assertEqual([], self.sock_diag.Dump(req, bad_bytecode))

    bytecode = self.PackAndCheckBytecode(
        [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (self.netid, 0xffff))])
    children = self.sock_diag.Dump(req, bytecode)
    return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
            for d, _ in children]

  def CheckChildSocket(self, version, statename, parent_first):
    """Checks destroying a listening socket and its child in either order.

    Args:
      version: The IP version of the incoming connection.
      statename: The name of a tcp_test state constant, e.g., "TCP_SYN_RECV".
      parent_first: If True, the listening socket is destroyed before its
        children; otherwise the children are destroyed first.
    """
    state = getattr(tcp_test, statename)

    self.IncomingConnection(version, state, self.netid)

    d = self.sock_diag.FindSockDiagFromFd(self.s)
    parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
    children = self.FindChildSockets(self.s)
    self.assertEqual(1, len(children))

    is_established = (state == tcp_test.TCP_NOT_YET_ACCEPTED)
    expected_state = tcp_test.TCP_ESTABLISHED if is_established else state

    for child in children:
      diag_msg, attrs = self.sock_diag.GetSockInfo(child)
      self.assertEqual(diag_msg.state, expected_state)
      self.assertMarkIs(self.netid, attrs)

    def CloseParent(expect_reset):
      msg = "Closing parent IPv%d %s socket %s child" % (
          version, statename, "before" if parent_first else "after")
      self.CheckRstOnClose(self.s, None, expect_reset, msg)
      self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, parent)

    def CheckChildrenClosed():
      for child in children:
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)

    def CloseChildren():
      for child in children:
        msg = "Closing child IPv%d %s socket %s parent" % (
            version, statename, "after" if parent_first else "before")
        self.sock_diag.GetSockInfo(child)
        self.CheckRstOnClose(None, child, is_established, msg)
        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)
      CheckChildrenClosed()

    if parent_first:
      # Closing the parent will close child sockets, which will send a RST,
      # iff they are already established.
      CloseParent(is_established)
      if is_established:
        CheckChildrenClosed()
      else:
        CloseChildren()
        CheckChildrenClosed()
    else:
      CloseChildren()
      CloseParent(False)

    self.CloseSockets()

  def testChildSockets(self):
    """Runs CheckChildSocket for every version, child state and close order."""
    for version in [4, 5, 6]:
      self.CheckChildSocket(version, "TCP_SYN_RECV", False)
      self.CheckChildSocket(version, "TCP_SYN_RECV", True)
      self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", False)
      self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", True)

  def testAcceptInterrupted(self):
    """Tests that accept() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid)
      self.assertRaisesErrno(ENOTCONN, self.s.recv, 4096)
      self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
      self.assertRaisesErrno(ECONNABORTED, self.s.send, b"foo")
      self.assertRaisesErrno(EINVAL, self.s.accept)
      # TODO: this should really return an error such as ENOTCONN...
      self.assertEqual(b"", self.s.recv(4096))
      self.CloseSockets()

  def testReadInterrupted(self):
    """Tests that read() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
      self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
                                   ECONNABORTED)
      # Writing returns EPIPE, and reading returns EOF.
      self.assertRaisesErrno(EPIPE, self.accepted.send, b"foo")
      self.assertEqual(b"", self.accepted.recv(4096))
      self.assertEqual(b"", self.accepted.recv(4096))
      self.CloseSockets()

  def testConnectInterrupted(self):
    """Tests that connect() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
      s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
      self.SelectInterface(s, self.netid, "mark")

      remotesockaddr = self.GetRemoteSocketAddress(version)
      remoteaddr = self.GetRemoteAddress(version)
      # Bind first so that the source port of the expected SYN is known.
      s.bind(("", 0))
      _, sport = s.getsockname()[:2]
      self.CloseDuringBlockingCall(
          s, lambda sock: sock.connect((remotesockaddr, 53)), ECONNABORTED)
      desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid),
                              remoteaddr, sport=sport, seq=None)
      self.ExpectPacketOn(self.netid, desc, syn)
      msg = "SOCK_DESTROY of socket in connect, expected no RST"
      self.ExpectNoPacketsOn(self.netid, msg)
      s.close()


class PollOnCloseTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """Tests that the effect of SOCK_DESTROY on poll matches TCP RSTs.

  The behaviour of poll() in these cases is not what we might expect: if only
  POLLIN is specified, it will return POLLIN|POLLERR|POLLHUP, but if POLLOUT
  is (also) specified, it will only return POLLOUT.
  """

  # Event masks shared by the test cases below.
  POLLIN_OUT = select.POLLIN | select.POLLOUT
  POLLIN_ERR_HUP = select.POLLIN | select.POLLERR | select.POLLHUP

  # Poll event bits paired with the short names used to render poll results.
  POLL_FLAGS = [(select.POLLIN, "IN"), (select.POLLOUT, "OUT"),
                (select.POLLERR, "ERR"), (select.POLLHUP, "HUP")]

  def setUp(self):
    """Runs the parent setUp, then picks a random netid from self.tuns."""
    super().setUp()
    netids = list(self.tuns.keys())
    self.netid = random.choice(netids)

  def PollResultToString(self, poll_events, ignoremask):
    """Renders poll results as (fd, "IN|OUT|...") pairs.

    Args:
      poll_events: An iterable of (fd, event bitmask) pairs.
      ignoremask: Event bits to leave out of the rendered names.

    Returns:
      A list of (fd, string) tuples, where the string joins with "|" the names
      of the POLL_FLAGS bits set in the event and not set in ignoremask.
    """
    rendered = []
    for fd, event in poll_events:
      relevant = event & ~ignoremask
      names = []
      for flag, name in self.POLL_FLAGS:
        if relevant & flag:
          names.append(name)
      rendered.append((fd, "|".join(names)))
    return rendered

  def BlockingPoll(self, sock, mask, expected, ignoremask):
    """Polls sock for mask and asserts that the returned events are expected.

    Args:
      sock: The socket to poll.
      mask: The event mask to register sock with.
      expected: The event bitmask poll() is expected to report for sock.
      ignoremask: Event bits excluded from the comparison.
    """
    poller = select.poll()
    poller.register(sock, mask)
    want = [(sock.fileno(), expected)]
    # Don't block forever or we'll hang continuous test runs on failure.
    # A 5-second timeout should be long enough not to be flaky.
    got = poller.poll(5000)
    want_str = self.PollResultToString(want, ignoremask)
    got_str = self.PollResultToString(got, ignoremask)
    self.assertEqual(want_str, got_str)

  def RstDuringBlockingCall(self, sock, call, expected_errno):
    """Like CloseDuringBlockingCall, but the event is an incoming RST."""
    def ReceiveRst(_):
      return self.ReceiveRstPacketOn(self.netid)

    self._EventDuringBlockingCall(sock, call, expected_errno, ReceiveRst)

  def assertSocketErrors(self, errno):
    """Checks self.accepted reports errno once, then behaves as closed."""
    accepted = self.accepted

    # The first operation returns the expected errno.
    self.assertRaisesErrno(errno, accepted.recv, 4096)

    # Subsequent operations behave as normal.
    self.assertRaisesErrno(EPIPE, accepted.send, b"foo")
    for _ in range(2):
      self.assertEqual(b"", accepted.recv(4096))

  def CheckPollDestroy(self, mask, expected, ignoremask):
    """Interrupts a poll() with SOCK_DESTROY."""
    def Poll(sock):
      return self.BlockingPoll(sock, mask, expected, ignoremask)

    for version in (4, 5, 6):
      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
      self.CloseDuringBlockingCall(self.accepted, Poll, None)
      self.assertSocketErrors(ECONNABORTED)
      self.CloseSockets()

  def CheckPollRst(self, mask, expected, ignoremask):
    """Interrupts a poll() by receiving a TCP RST."""
    def Poll(sock):
      return self.BlockingPoll(sock, mask, expected, ignoremask)

    for version in (4, 5, 6):
      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
      self.RstDuringBlockingCall(self.accepted, Poll, None)
      self.assertSocketErrors(ECONNRESET)
      self.CloseSockets()

  def testReadPollRst(self):
    """Polling for POLLIN across a RST reports POLLIN|POLLERR|POLLHUP."""
    self.CheckPollRst(select.POLLIN, self.POLLIN_ERR_HUP, 0)

  def testWritePollRst(self):
    """Polling for POLLOUT across a RST reports POLLOUT."""
    self.CheckPollRst(select.POLLOUT, select.POLLOUT, 0)

  def testReadWritePollRst(self):
    """Polling for POLLIN|POLLOUT across a RST reports only POLLOUT."""
    self.CheckPollRst(self.POLLIN_OUT, select.POLLOUT, 0)

  def testReadPollDestroy(self):
    """Polling for POLLIN across SOCK_DESTROY; POLLIN and POLLHUP ignored."""
    # tcp_abort has the same race that tcp_reset has, but it's not fixed yet.
    racy_bits = select.POLLIN | select.POLLHUP
    self.CheckPollDestroy(select.POLLIN, self.POLLIN_ERR_HUP, racy_bits)

  def testWritePollDestroy(self):
    """Polling for POLLOUT across SOCK_DESTROY reports POLLOUT."""
    self.CheckPollDestroy(select.POLLOUT, select.POLLOUT, 0)

  def testReadWritePollDestroy(self):
    """Polling for POLLIN|POLLOUT across SOCK_DESTROY reports only POLLOUT."""
    self.CheckPollDestroy(self.POLLIN_OUT, select.POLLOUT, 0)


class SockDestroyUdpTest(SockDiagBaseTest):

  """Tests SOCK_DESTROY on UDP sockets.

    Relevant kernel commits:
      upstream net-next:
        5d77dca net: diag: support SOCK_DESTROY for UDP sockets
        f95bf34 net: diag: make udp_diag_destroy work for mapped addresses.
  """

  def testClosesUdpSockets(self):
    """Destroys both sockets of every UDP socket pair, one at a time.

    Each socket must look connected before SOCK_DESTROY and closed after it.
    """
    self.socketpairs = self._CreateLotsOfSockets(SOCK_DGRAM)
    for s1, s2 in self.socketpairs.values():
      self.assertSocketConnected(s1)
      self.sock_diag.CloseSocketFromFd(s1)
      self.assertSocketClosed(s1)

      self.assertSocketConnected(s2)
      self.sock_diag.CloseSocketFromFd(s2)
      self.assertSocketClosed(s2)

  def BindToRandomPort(self, s, addr):
    """Binds s to addr on a randomly chosen port.

    Args:
      s: The socket to bind.
      addr: The address to bind to; passed to bind() unchanged.

    Returns:
      The port number that was successfully bound.

    Raises:
      ValueError: Every attempt failed with EADDRINUSE.
      error: bind() failed with any other errno.
    """
    ATTEMPTS = 20
    # Use the constant for the loop bound so that the retry count and the
    # count reported in the error message below cannot drift apart.
    for _ in range(ATTEMPTS):
      port = random.randrange(1024, 65535)
      try:
        s.bind((addr, port))
        return port
      except error as e:
        if e.errno != EADDRINUSE:
          # Bare raise keeps the original traceback intact.
          raise
    raise ValueError("Could not find a free port on %s after %d attempts" %
                     (addr, ATTEMPTS))

  def testSocketAddressesAfterClose(self):
    """Checks what getsockname() returns after SOCK_DESTROY on a UDP socket.

    For each IP version, covers four ways of binding the socket (none, address
    only, port only, address and port) and checks which parts of the local
    address are still reported after the socket has been destroyed.
    """
    for version in 4, 5, 6:
      netid = random.choice(self.NETIDS)
      dst = self.GetRemoteSocketAddress(version)
      unspec = {4: "0.0.0.0", 5: "::", 6: "::"}[version]

      # Closing a socket that was not explicitly bound (i.e., bound via
      # connect(), not bind()) clears the source address and port.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      self.SelectInterface(s, netid, "mark")
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((unspec, 0), s.getsockname()[:2])
      s.close()

      # Closing a socket bound to an IP address leaves the address as is.
      # The port was not requested explicitly (bind to port 0), and is
      # expected to read back as 0 afterwards.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      src = self.MySocketAddress(version, netid)
      s.bind((src, 0))
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((src, 0), s.getsockname()[:2])
      s.close()

      # Closing a socket bound to a port leaves the port as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      port = self.BindToRandomPort(s, "")
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((unspec, port), s.getsockname()[:2])
      s.close()

      # Closing a socket bound to IP address and port leaves both as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      src = self.MySocketAddress(version, netid)
      port = self.BindToRandomPort(s, src)
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((src, port), s.getsockname()[:2])
      s.close()

  def testReadInterrupted(self):
    """Tests that read() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
      s = net_test.UDPSocket(family)
      self.SelectInterface(s, random.choice(self.NETIDS), "mark")
      addr = self.GetRemoteSocketAddress(version)

      # Check that reads on connected sockets are interrupted.
      s.connect((addr, 53))
      self.assertEqual(3, s.send(b"foo"))
      self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
                                   ECONNABORTED)

      # A destroyed socket is no longer connected, but still usable.
      self.assertRaisesErrno(EDESTADDRREQ, s.send, b"foo")
      self.assertEqual(3, s.sendto(b"foo", (addr, 53)))

      # Check that reads on unconnected sockets are also interrupted.
      self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
                                   ECONNABORTED)

      s.close()

class SockDestroyPermissionTest(SockDiagBaseTest):
  """Tests that SOCK_DESTROY is refused when requested by another UID."""

  def CheckPermissions(self, socktype):
    """Checks who is allowed to destroy an IPv6 socket of the given type.

    A SOCK_STREAM socket is put into the listening state; any other type is
    connected to port 53 of the IPv6 remote address. Destroying the socket
    while running as UID 12345 must fail with EPERM. Destroying it as the
    test's own user must succeed, after which a second attempt is expected to
    raise ValueError.

    Args:
      socktype: The socket type to create, e.g. SOCK_STREAM or SOCK_DGRAM.
    """
    s = socket(AF_INET6, socktype, 0)
    # Close the socket even when an assertion fails, so a failing run does
    # not leave a stray socket behind for later tests.
    try:
      self.SelectInterface(s, random.choice(self.NETIDS), "mark")
      if socktype == SOCK_STREAM:
        s.listen(1)
      else:
        s.connect((self.GetRemoteAddress(6), 53))

      with net_test.RunAsUid(12345):
        self.assertRaisesErrno(
            EPERM, self.sock_diag.CloseSocketFromFd, s)

      self.sock_diag.CloseSocketFromFd(s)
      self.assertRaises(ValueError, self.sock_diag.CloseSocketFromFd, s)
    finally:
      s.close()

  def testUdp(self):
    """Runs the permission check on a UDP socket."""
    self.CheckPermissions(SOCK_DGRAM)

  def testTcp(self):
    """Runs the permission check on a TCP socket."""
    self.CheckPermissions(SOCK_STREAM)


class SockDiagMarkTest(tcp_test.TcpBaseTest, SockDiagBaseTest):

  """Tests SOCK_DIAG bytecode filters that use marks.

    Relevant kernel commits:
      upstream net-next:
        627cc4a net: diag: slightly refactor the inet_diag_bc_audit error checks.
        a52e95a net: diag: allow socket bytecode filters to match socket marks
        d545cac net: inet: diag: expose the socket mark to privileged processes.
  """

  def FilterEstablishedSockets(self, mark, mask):
    """Dumps established TCP sockets through a mark-matching bytecode filter.

    Args:
      mark: The mark value placed in the INET_DIAG_BC_MARK_COND instruction.
      mask: The mask placed alongside it in the same instruction.

    Returns:
      The result of DumpAllInetSockets for IPPROTO_TCP, restricted to the
      TCP_ESTABLISHED state and filtered by the generated bytecode.
    """
    instructions = [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (mark, mask))]
    bytecode = self.sock_diag.PackBytecode(instructions)
    return self.sock_diag.DumpAllInetSockets(
        IPPROTO_TCP, bytecode, states=(1 << tcp_test.TCP_ESTABLISHED))

  def assertSamePorts(self, ports, diag_msgs):
    """Asserts that the source ports in diag_msgs are exactly ports.

    Args:
      ports: The expected port numbers, in any order.
      diag_msgs: Dump entries whose first element has an id.sport field.
    """
    expected = sorted(ports)
    actual = sorted(msg[0].id.sport for msg in diag_msgs)
    self.assertEqual(expected, actual)

  def SockInfoMatchesSocket(self, s, info):
    """Returns True iff assertSockInfoMatchesSocket(s, info) passes."""
    try:
      self.assertSockInfoMatchesSocket(s, info)
      return True
    except AssertionError:
      return False

  @staticmethod
  def SocketDescription(s):
    """Returns "<local address> -> <peer address>" for a connected socket."""
    return "%s -> %s" % (str(s.getsockname()), str(s.getpeername()))

  def assertFoundSockets(self, infos, sockets):
    """Asserts that infos and sockets correspond one to one.

    Fails if a socket matches no entry of infos, if a socket matches more than
    one entry, or if an entry of infos is not matched by any socket.

    Args:
      infos: Dump entries, as returned by FilterEstablishedSockets.
      sockets: The sockets that are expected to appear in the dump.
    """
    matches = {}
    for s in sockets:
      match = None
      for info in infos:
        if self.SockInfoMatchesSocket(s, info):
          if match is not None:
            self.fail("Socket %s matched both %s and %s" %
                      (self.SocketDescription(s), match, info))
          # Remember the match so that a second one triggers the check above.
          match = info
      self.assertIsNotNone(match, "Did not find socket %s in dump" %
                           self.SocketDescription(s))
      matches[s] = match

    # Build the list of matched entries once rather than once per entry.
    matched = list(matches.values())
    for i in infos:
      if i not in matched:
        self.fail("Too many sockets in dump, first unexpected: %s" % str(i))

  def testMarkBytecode(self):
    """Checks that mark filters select the expected sockets of a TCP pair.

    The two ends of a connected pair get marks 0xfff1234 and 0xf0f1235. The
    expected results below are consistent with a filter that matches when
    (socket mark & mask) == mark. Running the filter as UID 12345 must fail
    with EPERM.
    """
    family, addr = random.choice([
        (AF_INET, "127.0.0.1"),
        (AF_INET6, "::1"),
        (AF_INET6, "::ffff:127.0.0.1")])
    s1, s2 = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
    s1.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xfff1234)
    s2.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xf0f1235)

    infos = self.FilterEstablishedSockets(0x1234, 0xffff)
    self.assertFoundSockets(infos, [s1])

    # The mask drops the lowest bit, the only one in which the low 16 bits
    # of the two marks differ.
    infos = self.FilterEstablishedSockets(0x1234, 0xfffe)
    self.assertFoundSockets(infos, [s1, s2])

    infos = self.FilterEstablishedSockets(0x1235, 0xffff)
    self.assertFoundSockets(infos, [s2])

    infos = self.FilterEstablishedSockets(0x0, 0x0)
    self.assertFoundSockets(infos, [s1, s2])

    infos = self.FilterEstablishedSockets(0xfff0000, 0xf0fed00)
    self.assertEqual(0, len(infos))

    with net_test.RunAsUid(12345):
      self.assertRaisesErrno(EPERM, self.FilterEstablishedSockets,
                             0xfff0000, 0xf0fed00)

    s1.close()
    s2.close()

  @staticmethod
  def SetRandomMark(s):
    """Sets a random SO_MARK on s and returns the mark that was set."""
    # Python doesn't like marks that don't fit into a signed int.
    mark = random.randrange(0, 2**31 - 1)
    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, mark)
    return mark

  def assertSocketMarkIs(self, s, mark):
    """Asserts the mark reported for s, both as the test user and as UID 12345.

    Args:
      s: The socket to look up via sock_diag.
      mark: The mark expected in the attributes when looked up as the test's
        own user. When looked up as UID 12345, None is expected instead.
    """
    _, attrs = self.sock_diag.FindSockInfoFromFd(s)
    self.assertMarkIs(mark, attrs)
    with net_test.RunAsUid(12345):
      _, attrs = self.sock_diag.FindSockInfoFromFd(s)
      self.assertMarkIs(None, attrs)

  def testMarkInAttributes(self):
    """Checks the reported mark for TCP, UDP and, if available, SCTP sockets.

    Runs once for each of IPv4, IPv6 and IPv4-mapped IPv6 loopback addresses.
    """
    testcases = [(AF_INET, "127.0.0.1"),
                 (AF_INET6, "::1"),
                 (AF_INET6, "::ffff:127.0.0.1")]
    for family, addr in testcases:
      # TCP listen sockets.
      server = socket(family, SOCK_STREAM, 0)
      server.bind((addr, 0))
      port = server.getsockname()[1]
      server.listen(1)  # Or the socket won't be in the hashtables.
      server_mark = self.SetRandomMark(server)
      self.assertSocketMarkIs(server, server_mark)

      # TCP client sockets.
      client = socket(family, SOCK_STREAM, 0)
      client_mark = self.SetRandomMark(client)
      client.connect((addr, port))
      self.assertSocketMarkIs(client, client_mark)

      # TCP server sockets. The accepted socket is expected to start out
      # with the listening socket's mark.
      accepted, _ = server.accept()
      self.assertSocketMarkIs(accepted, server_mark)

      # Changing the accepted socket's mark must not affect the listener's.
      accepted_mark = self.SetRandomMark(accepted)
      self.assertSocketMarkIs(accepted, accepted_mark)
      self.assertSocketMarkIs(server, server_mark)

      accepted.close()
      server.close()
      client.close()

      # Other TCP states are tested in SockDestroyTcpTest.

      # UDP sockets.
      s = socket(family, SOCK_DGRAM, 0)
      mark = self.SetRandomMark(s)
      s.connect(("", 53))
      self.assertSocketMarkIs(s, mark)
      s.close()

      # Basic test for SCTP. sctp_diag was only added in 4.7.
      if HAVE_SCTP:
        s = socket(family, SOCK_STREAM, IPPROTO_SCTP)
        s.bind((addr, 0))
        s.listen(1)
        mark = self.SetRandomMark(s)
        self.assertSocketMarkIs(s, mark)
        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_SCTP, NO_BYTECODE)
        self.assertEqual(1, len(sockets))
        self.assertEqual(mark, sockets[0][1].get("INET_DIAG_MARK", None))
        s.close()


# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  unittest.main()
