# mypy: ignore-errors

# Owner(s): ["module: dataloader"]

import copy
import itertools
import os
import os.path
import pickle
import pydoc
import random
import sys
import tempfile
import warnings
from functools import partial
from typing import (
    Any,
    Awaitable,
    Dict,
    Generic,
    Iterator,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    TYPE_CHECKING,
    TypeVar,
    Union,
)

if not TYPE_CHECKING:
    # pyre isn't treating this the same as a typing.NamedTuple
    from typing_extensions import NamedTuple
else:
    from typing import NamedTuple

import operator
from unittest import skipIf

import numpy as np

import torch
import torch.nn as nn
import torch.utils.data.datapipes as dp
import torch.utils.data.graph
import torch.utils.data.graph_settings
from torch.testing._internal.common_utils import (
    run_tests,
    skipIfNoDill,
    skipIfTorchDynamo,
    suppress_warnings,
    TEST_DILL,
    TestCase,
)
from torch.utils._import_utils import import_dill
from torch.utils.data import (
    argument_validation,
    DataChunk,
    DataLoader,
    IterDataPipe,
    MapDataPipe,
    RandomSampler,
    runtime_validation,
    runtime_validation_disabled,
)
from torch.utils.data.datapipes.dataframe import (
    CaptureDataFrame,
    dataframe_wrapper as df_wrapper,
)
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torch.utils.data.datapipes.utils.common import StreamWrapper
from torch.utils.data.datapipes.utils.decoder import (
    basichandlers as decoder_basichandlers,
)
from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration
from torch.utils.data.graph import traverse_dps

dill = import_dill()
HAS_DILL = TEST_DILL

try:
    import pandas  # type: ignore[import]  # noqa: F401 F403

    HAS_PANDAS = True
except ImportError:
    HAS_PANDAS = False
skipIfNoDataFrames = skipIf(not HAS_PANDAS, "no dataframes (pandas)")

skipTyping = skipIf(True, "TODO: Fix typing bug")
T_co = TypeVar("T_co", covariant=True)


def create_temp_dir_and_files():
    """Create a temporary directory tree populated with small test files.

    Layout (base names are randomized by ``tempfile``)::

        temp_dir/
            *.txt     text    "0123456789abcdef"
            *.byte    binary  b"0123456789abcdef"
            *.empty   empty file
            temp_sub_dir/
                *.txt   text    "0123456789abcdef"
                *.byte  binary  b"0123456789abcdef"

    Returns:
        ``[(temp_dir, txt, byte, empty), (temp_sub_dir, sub_txt, sub_byte)]``:
        the two ``tempfile.TemporaryDirectory`` objects (which the caller must
        ``cleanup()``, e.g. in ``tearDown()``) followed by their file paths.
    """
    payload = "0123456789abcdef"

    def _reserve(directory, suffix):
        # Create a file that survives closing (delete=False) and hand back its path.
        with tempfile.NamedTemporaryFile(
            dir=directory, delete=False, suffix=suffix
        ) as handle:
            return handle.name

    def _fill(text_path, binary_path):
        # Same payload in both files: text mode for one, raw bytes for the other.
        with open(text_path, "w") as text_out:
            text_out.write(payload)
        with open(binary_path, "wb") as binary_out:
            binary_out.write(payload.encode())

    # The directory handles are intentionally not released here: the caller's
    # tearDown() cleans them up, hence `noqa: P201` on the lint warning.
    temp_dir = tempfile.TemporaryDirectory()  # noqa: P201
    top_txt = _reserve(temp_dir.name, ".txt")
    top_byte = _reserve(temp_dir.name, ".byte")
    top_empty = _reserve(temp_dir.name, ".empty")
    _fill(top_txt, top_byte)

    temp_sub_dir = tempfile.TemporaryDirectory(dir=temp_dir.name)  # noqa: P201
    sub_txt = _reserve(temp_sub_dir.name, ".txt")
    sub_byte = _reserve(temp_sub_dir.name, ".byte")
    _fill(sub_txt, sub_byte)

    return [
        (temp_dir, top_txt, top_byte, top_empty),
        (temp_sub_dir, sub_txt, sub_byte),
    ]


def reset_after_n_next_calls(
    datapipe: Union[IterDataPipe[T_co], MapDataPipe[T_co]], n: int
) -> Tuple[List[T_co], List[T_co]]:
    """
    Pull ``n`` elements from a fresh iterator over ``datapipe``, then list the
    whole DataPipe again so the caller can observe the effect of a reset.

    Returns:
        A pair ``(before, after)``:
            1. the first ``n`` elements yielded before the reset
            2. all elements yielded by a new full pass after the reset

    Raises:
        StopIteration: if ``datapipe`` yields fewer than ``n`` elements.
    """
    partial_iter = iter(datapipe)
    before = [next(partial_iter) for _ in range(n)]
    after = list(datapipe)
    return before, after


def odd_or_even(x: int) -> int:
    """Return the parity of ``x`` (``x % 2``): ``1`` for odd, ``0`` for even."""
    parity = x % 2
    return parity


class TestDataChunk(TestCase):
    def setUp(self):
        self.elements = list(range(10))
        random.shuffle(self.elements)
        self.chunk: DataChunk[int] = DataChunk(self.elements)

    def test_getitem(self):
        for i in range(10):
            self.assertEqual(self.elements[i], self.chunk[i])

    def test_iter(self):
        for ele, dc in zip(self.elements, iter(self.chunk)):
            self.assertEqual(ele, dc)

    def test_len(self):
        self.assertEqual(len(self.elements), len(self.chunk))

    def test_as_string(self):
        self.assertEqual(str(self.chunk), str(self.elements))

        batch = [self.elements] * 3
        chunks: List[DataChunk[int]] = [DataChunk(self.elements)] * 3
        self.assertEqual(str(batch), str(chunks))

    def test_sort(self):
        chunk: DataChunk[int] = DataChunk(self.elements)
        chunk.sort()
        self.assertTrue(isinstance(chunk, DataChunk))
        for i, d in enumerate(chunk):
            self.assertEqual(i, d)

    def test_reverse(self):
        chunk: DataChunk[int] = DataChunk(self.elements)
        chunk.reverse()
        self.assertTrue(isinstance(chunk, DataChunk))
        for i in range(10):
            self.assertEqual(chunk[i], self.elements[9 - i])

    def test_random_shuffle(self):
        elements = list(range(10))
        chunk: DataChunk[int] = DataChunk(elements)

        rng = random.Random(0)
        rng.shuffle(chunk)

        rng = random.Random(0)
        rng.shuffle(elements)

        self.assertEqual(chunk, elements)


class TestStreamWrapper(TestCase):
    class _FakeFD:
        def __init__(self, filepath):
            self.filepath = filepath
            self.opened = False
            self.closed = False

        def open(self):
            self.opened = True

        def read(self):
            if self.opened:
                return "".join(self)
            else:
                raise OSError("Cannot read from un-opened file descriptor")

        def __iter__(self):
            for i in range(5):
                yield str(i)

        def close(self):
            if self.opened:
                self.opened = False
                self.closed = True

        def __repr__(self):
            return "FakeFD"

    def test_dir(self):
        fd = TestStreamWrapper._FakeFD("")
        wrap_fd = StreamWrapper(fd)

        s = set(dir(wrap_fd))
        for api in ["open", "read", "close"]:
            self.assertTrue(api in s)

    @skipIfTorchDynamo()
    def test_api(self):
        fd = TestStreamWrapper._FakeFD("")
        wrap_fd = StreamWrapper(fd)

        self.assertFalse(fd.opened)
        self.assertFalse(fd.closed)
        with self.assertRaisesRegex(IOError, "Cannot read from"):
            wrap_fd.read()

        wrap_fd.open()
        self.assertTrue(fd.opened)
        self.assertEqual("01234", wrap_fd.read())

        del wrap_fd
        self.assertFalse(fd.opened)
        self.assertTrue(fd.closed)

    def test_pickle(self):
        with tempfile.TemporaryFile() as f:
            with self.assertRaises(TypeError) as ctx1:
                pickle.dumps(f)

            wrap_f = StreamWrapper(f)
            with self.assertRaises(TypeError) as ctx2:
                pickle.dumps(wrap_f)

            # Same exception when pickle
            self.assertEqual(str(ctx1.exception), str(ctx2.exception))

        fd = TestStreamWrapper._FakeFD("")
        wrap_fd = StreamWrapper(fd)
        _ = pickle.loads(pickle.dumps(wrap_fd))

    def test_repr(self):
        fd = TestStreamWrapper._FakeFD("")
        wrap_fd = StreamWrapper(fd)
        self.assertEqual(str(wrap_fd), "StreamWrapper<FakeFD>")

        with tempfile.TemporaryFile() as f:
            wrap_f = StreamWrapper(f)
            self.assertEqual(str(wrap_f), "StreamWrapper<" + str(f) + ">")


class TestIterableDataPipeBasic(TestCase):
    """Tests for basic IterDataPipes: file listing/opening, routed decoding,
    grouping, demux/mux and column-wise mapping."""

    def setUp(self):
        # create_temp_dir_and_files() returns
        #   [(temp_dir, *top_level_files), (temp_sub_dir, *sub_dir_files)]
        top_level, sub_level = create_temp_dir_and_files()
        self.temp_dir = top_level[0]
        self.temp_files = top_level[1:]
        self.temp_sub_dir = sub_level[0]
        self.temp_sub_files = sub_level[1:]

    def tearDown(self):
        # Best-effort cleanup: failures are surfaced as a warning rather than
        # failing the test. The sub directory is removed before its parent.
        try:
            self.temp_sub_dir.cleanup()
            self.temp_dir.cleanup()
        except Exception as e:
            warnings.warn(
                f"TestIterableDataPipeBasic was not able to cleanup temp dir due to {e}"
            )

    def test_listdirfiles_iterable_datapipe(self):
        temp_dir = self.temp_dir.name
        datapipe: IterDataPipe = dp.iter.FileLister(temp_dir, "")

        # Non-recursive listing yields only the top-level files
        count = 0
        for pathname in datapipe:
            count += 1
            self.assertIn(pathname, self.temp_files)
        self.assertEqual(count, len(self.temp_files))

        # Recursive listing also yields the files of the sub directory
        count = 0
        datapipe = dp.iter.FileLister(temp_dir, "", recursive=True)
        all_files = (*self.temp_files, *self.temp_sub_files)
        for pathname in datapipe:
            count += 1
            self.assertIn(pathname, all_files)
        self.assertEqual(count, len(self.temp_files) + len(self.temp_sub_files))

        # A directory plus its files given explicitly: every file is yielded twice
        temp_files = self.temp_files
        datapipe = dp.iter.FileLister([temp_dir, *temp_files])
        count = 0
        for pathname in datapipe:
            count += 1
            self.assertIn(pathname, self.temp_files)
        self.assertEqual(count, 2 * len(self.temp_files))

        # test functional API
        datapipe = datapipe.list_files()
        count = 0
        for pathname in datapipe:
            count += 1
            self.assertIn(pathname, self.temp_files)
        self.assertEqual(count, 2 * len(self.temp_files))

    def test_listdirfilesdeterministic_iterable_datapipe(self):
        temp_dir = self.temp_dir.name

        datapipe = dp.iter.FileLister(temp_dir, "")
        # The output order should be always the same.
        self.assertEqual(list(datapipe), list(datapipe))

        datapipe = dp.iter.FileLister(temp_dir, "", recursive=True)
        # The output order should be always the same.
        self.assertEqual(list(datapipe), list(datapipe))

    def test_openfilesfromdisk_iterable_datapipe(self):
        # test import datapipe class directly
        from torch.utils.data.datapipes.iter import FileLister, FileOpener

        temp_dir = self.temp_dir.name
        datapipe1 = FileLister(temp_dir, "")
        datapipe2 = FileOpener(datapipe1, mode="b")

        count = 0
        for rec in datapipe2:
            count += 1
            self.assertIn(rec[0], self.temp_files)
            with open(rec[0], "rb") as f:
                self.assertEqual(rec[1].read(), f.read())
                rec[1].close()
        self.assertEqual(count, len(self.temp_files))

        # functional API
        datapipe3 = datapipe1.open_files(mode="b")

        count = 0
        for rec in datapipe3:
            count += 1
            self.assertIn(rec[0], self.temp_files)
            with open(rec[0], "rb") as f:
                self.assertEqual(rec[1].read(), f.read())
                rec[1].close()
        self.assertEqual(count, len(self.temp_files))

        # __len__ Test
        with self.assertRaises(TypeError):
            len(datapipe3)

    def test_routeddecoder_iterable_datapipe(self):
        temp_dir = self.temp_dir.name
        temp_pngfile_pathname = os.path.join(temp_dir, "test_png.png")
        png_data = np.array(
            [[[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]], [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]],
            dtype=np.single,
        )
        # NOTE(review): np.save appends ".npy" to a path that does not already end
        # in ".npy", so the file written here is "test_png.png.npy". It does not
        # match the "*.png" mask below, so the ".png" branch of _helper is likely
        # never exercised -- confirm before relying on this test for PNG decoding.
        np.save(temp_pngfile_pathname, png_data)
        datapipe1 = dp.iter.FileLister(temp_dir, ["*.png", "*.txt"])
        datapipe2 = dp.iter.FileOpener(datapipe1, mode="b")

        def _png_decoder(extension, data):
            if extension != "png":
                return None
            return np.load(data)

        def _helper(prior_dp, dp, channel_first=False):
            # Byte stream is not closed
            for inp in prior_dp:
                self.assertFalse(inp[1].closed)
            for inp, rec in zip(prior_dp, dp):
                ext = os.path.splitext(rec[0])[1]
                if ext == ".png":
                    expected = np.array(
                        [
                            [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
                            [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
                        ],
                        dtype=np.single,
                    )
                    if channel_first:
                        expected = expected.transpose(2, 0, 1)
                    self.assertEqual(rec[1], expected)
                else:
                    with open(rec[0], "rb") as f:
                        self.assertEqual(rec[1], f.read().decode("utf-8"))
                # Corresponding byte stream is closed by Decoder
                self.assertTrue(inp[1].closed)

        # Warnings emitted while constructing RoutedDecoder are recorded, not shown
        cached = list(datapipe2)
        with warnings.catch_warnings(record=True):
            datapipe3 = dp.iter.RoutedDecoder(cached, _png_decoder)
        datapipe3.add_handler(decoder_basichandlers)
        _helper(cached, datapipe3)

        cached = list(datapipe2)
        with warnings.catch_warnings(record=True):
            datapipe4 = dp.iter.RoutedDecoder(cached, decoder_basichandlers)
        datapipe4.add_handler(_png_decoder)
        _helper(cached, datapipe4, channel_first=True)

    def test_groupby_iterable_datapipe(self):
        file_list = [
            "a.png",
            "b.png",
            "c.json",
            "a.json",
            "c.png",
            "b.json",
            "d.png",
            "d.json",
            "e.png",
            "f.json",
            "g.png",
            "f.png",
            "g.json",
            "e.json",
            "h.txt",
            "h.json",
        ]

        import io

        datapipe1 = dp.iter.IterableWrapper(
            [(filename, io.BytesIO(b"12345abcde")) for filename in file_list]
        )

        def group_fn(data):
            # Group by base name without extension, e.g. "a.png" -> "a"
            filepath, _ = data
            return os.path.basename(filepath).split(".")[0]

        datapipe2 = dp.iter.Grouper(datapipe1, group_key_fn=group_fn, group_size=2)

        def order_fn(data):
            # Deterministic in-group order: file names in descending order
            data.sort(key=lambda f: f[0], reverse=True)
            return data

        datapipe3 = dp.iter.Mapper(datapipe2, fn=order_fn)  # type: ignore[var-annotated]

        expected_result = [
            ("a.png", "a.json"),
            ("c.png", "c.json"),
            ("b.png", "b.json"),
            ("d.png", "d.json"),
            ("f.png", "f.json"),
            ("g.png", "g.json"),
            ("e.png", "e.json"),
            ("h.txt", "h.json"),
        ]

        count = 0
        for rec, expected in zip(datapipe3, expected_result):
            count += 1
            self.assertEqual(os.path.basename(rec[0][0]), expected[0])
            self.assertEqual(os.path.basename(rec[1][0]), expected[1])
            for i in [0, 1]:
                self.assertEqual(rec[i][1].read(), b"12345abcde")
                rec[i][1].close()
        self.assertEqual(count, 8)

        # testing the keep_key option
        datapipe4 = dp.iter.Grouper(
            datapipe1, group_key_fn=group_fn, keep_key=True, group_size=2
        )

        def order_keyed_fn(data):
            # With keep_key=True each item is (key, group); sort the group part
            data[1].sort(key=lambda f: f[0], reverse=True)
            return data

        datapipe5 = dp.iter.Mapper(datapipe4, fn=order_keyed_fn)  # type: ignore[var-annotated]

        expected_result = [
            ("a", ("a.png", "a.json")),
            ("c", ("c.png", "c.json")),
            ("b", ("b.png", "b.json")),
            ("d", ("d.png", "d.json")),
            ("f", ("f.png", "f.json")),
            ("g", ("g.png", "g.json")),
            ("e", ("e.png", "e.json")),
            ("h", ("h.txt", "h.json")),
        ]

        count = 0
        for rec, expected in zip(datapipe5, expected_result):
            count += 1
            self.assertEqual(rec[0], expected[0])
            self.assertEqual(rec[1][0][0], expected[1][0])
            self.assertEqual(rec[1][1][0], expected[1][1])
            for i in [0, 1]:
                self.assertEqual(rec[1][i][1].read(), b"12345abcde")
                rec[1][i][1].close()
        self.assertEqual(count, 8)

    def test_demux_mux_datapipe(self):
        numbers = NumbersDataset(10)
        n1, n2 = numbers.demux(2, lambda x: x % 2)
        self.assertEqual([0, 2, 4, 6, 8], list(n1))
        self.assertEqual([1, 3, 5, 7, 9], list(n2))

        # Functional Test: demux and mux works sequentially as expected
        numbers = NumbersDataset(10)
        n1, n2, n3 = numbers.demux(3, lambda x: x % 3)
        n = n1.mux(n2, n3)
        self.assertEqual(list(range(9)), list(n))

        # Functional Test: Uneven DataPipes
        source_numbers = list(range(0, 10)) + [10, 12]
        numbers_dp = dp.iter.IterableWrapper(source_numbers)
        n1, n2 = numbers_dp.demux(2, lambda x: x % 2)
        self.assertEqual([0, 2, 4, 6, 8, 10, 12], list(n1))
        self.assertEqual([1, 3, 5, 7, 9], list(n2))
        n = n1.mux(n2)
        self.assertEqual(list(range(10)), list(n))

    @suppress_warnings  # Suppress warning for lambda fn
    def test_map_with_col_file_handle_datapipe(self):
        temp_dir = self.temp_dir.name
        datapipe1 = dp.iter.FileLister(temp_dir, "")
        datapipe2 = dp.iter.FileOpener(datapipe1)

        def _helper(datapipe):
            # Mapping only column 1 must match mapping the whole record by hand
            dp1 = datapipe.map(lambda x: x.read(), input_col=1)
            dp2 = datapipe.map(lambda x: (x[0], x[1].read()))
            self.assertEqual(list(dp1), list(dp2))

        # tuple
        _helper(datapipe2)
        # list
        datapipe3 = datapipe2.map(lambda x: list(x))
        _helper(datapipe3)


@skipIfNoDataFrames
class TestCaptureDataFrame(TestCase):
    """Operations recorded on a ``CaptureDataFrame`` and replayed with
    ``apply_ops`` must match the same operations run eagerly."""

    def get_new_df(self):
        # Fresh one-row frame with columns "a" and "b"
        return df_wrapper.create_dataframe([[1, 2]], columns=["a", "b"])

    def compare_capture_and_eager(self, operations):
        # Record the operations lazily, then replay them on a fresh frame
        captured = operations(CaptureDataFrame())
        replayed = captured.apply_ops(self.get_new_df())

        # Apply the very same operations eagerly on another fresh frame
        eager = operations(self.get_new_df())

        self.assertTrue(eager.equals(replayed))

    def test_basic_capture(self):
        def operations(df):
            # Item assignment is used on purpose: the attribute form
            # `df.c = df.b + df['a'] * 7` somehow swallows a pandas UserWarning.
            df["c"] = df.b + df["a"] * 7
            return df

        self.compare_capture_and_eager(operations)


class TestDataFramesPipes(TestCase):
    """
    Most of these tests will fail if pandas is installed but dill is not available.
    Need to rework them to avoid multiple skips.
    """

    def _get_datapipe(self, range=10, dataframe_size=7):
        # Plain (i, i % 3) tuples for i in [0, range).
        # NOTE: `range` shadows the builtin inside this helper; the name is kept
        # for keyword callers. `dataframe_size` is unused and only mirrors the
        # signature of `_get_dataframes_pipe`.
        return NumbersDataset(range).map(lambda i: (i, i % 3))

    def _get_dataframes_pipe(self, range=10, dataframe_size=7):
        # Same (i, i % 3) rows as `_get_datapipe`, packed into dataframes with
        # columns "i" and "j" of at most `dataframe_size` rows each.
        return (
            NumbersDataset(range)
            .map(lambda i: (i, i % 3))
            ._to_dataframes_pipe(columns=["i", "j"], dataframe_size=dataframe_size)
        )

    @skipIfNoDataFrames
    @skipIfNoDill  # TODO(VitalyFedyunin): Decouple tests from dill by avoiding lambdas in map
    def test_capture(self):
        dp_numbers = self._get_datapipe().map(lambda x: (x[0], x[1], x[1] + 3 * x[0]))
        df_numbers = self._get_dataframes_pipe()
        df_numbers["k"] = df_numbers["j"] + df_numbers.i * 3
        expected = list(dp_numbers)
        actual = list(df_numbers)
        self.assertEqual(expected, actual)

    @skipIfNoDataFrames
    @skipIfNoDill
    def test_shuffle(self):
        #  With a non-zero (but extremely low) probability (when the shuffle leaves
        #  the order unchanged) this test fails, so feel free to rerun it.
        df_numbers = self._get_dataframes_pipe(range=1000).shuffle()
        dp_numbers = self._get_datapipe(range=1000)
        df_result = [tuple(item) for item in df_numbers]
        self.assertNotEqual(list(dp_numbers), df_result)
        self.assertEqual(list(dp_numbers), sorted(df_result))

    @skipIfNoDataFrames
    @skipIfNoDill
    def test_batch(self):
        df_numbers = self._get_dataframes_pipe(range=100).batch(8)
        df_numbers_list = list(df_numbers)
        # 100 rows in batches of 8 leaves 4 rows in the last batch
        last_batch = df_numbers_list[-1]
        self.assertEqual(4, len(last_batch))
        unpacked_batch = [tuple(row) for row in last_batch]
        self.assertEqual([(96, 0), (97, 1), (98, 2), (99, 0)], unpacked_batch)

    @skipIfNoDataFrames
    @skipIfNoDill
    def test_unbatch(self):
        df_numbers = self._get_dataframes_pipe(range=100).batch(8).batch(3)
        dp_numbers = self._get_datapipe(range=100)
        self.assertEqual(list(dp_numbers), list(df_numbers.unbatch(2)))

    @skipIfNoDataFrames
    @skipIfNoDill
    def test_filter(self):
        df_numbers = self._get_dataframes_pipe(range=10).filter(lambda x: x.i > 5)
        actual = list(df_numbers)
        self.assertEqual([(6, 0), (7, 1), (8, 2), (9, 0)], actual)

    @skipIfNoDataFrames
    @skipIfNoDill
    def test_collate(self):
        def collate_i(column):
            return column.sum()

        def collate_j(column):
            return column.prod()

        df_numbers = self._get_dataframes_pipe(range=30).batch(3)
        df_numbers = df_numbers.collate({"j": collate_j, "i": collate_i})

        # Each batch holds 3 consecutive ints, e.g. 0 + 1 + 2 = 3, 3 + 4 + 5 = 12
        expected_i = [
            3,
            12,
            21,
            30,
            39,
            48,
            57,
            66,
            75,
            84,
        ]

        actual_i = []
        for i, _ in df_numbers:
            actual_i.append(i)
        self.assertEqual(expected_i, actual_i)

        actual_i = []
        for item in df_numbers:
            actual_i.append(item.i)
        self.assertEqual(expected_i, actual_i)


class IDP_NoLen(IterDataPipe):
    """IterDataPipe wrapper that defines no ``__len__`` of its own.

    A source that is not an ``IterDataPipe`` is deep-copied each time iteration
    starts, so iterating never yields the caller's original objects.
    """

    def __init__(self, input_dp):
        super().__init__()
        self.input_dp = input_dp

    def __iter__(self):
        source = self.input_dp
        if not isinstance(source, IterDataPipe):
            # Prevent in-place modification of the caller's data
            source = copy.deepcopy(source)
        yield from source


def _fake_fn(data):
    return data


def _fake_add(constant, data):
    return constant + data


def _fake_filter_fn(data):
    return True


def _simple_filter_fn(data):
    return data >= 5


def _fake_filter_fn_constant(constant, data):
    return data >= constant


def _mul_10(x):
    return x * 10


def _mod_3_test(x):
    return x % 3 == 1


def _to_list(x):
    return [x]


# These must remain lambdas: the serialization tests expect pickle to reject them
# (they check for the "Lambda function is not supported by pickle" warning).
lambda_fn1 = lambda value: value  # noqa: E731
lambda_fn2 = lambda value: value % 2  # noqa: E731
lambda_fn3 = lambda value: value >= 5  # noqa: E731


class Add1Module(nn.Module):
    """``nn.Module`` whose forward pass adds one to its input."""

    def forward(self, x):
        incremented = x + 1
        return incremented


class Add1Callable:
    """Plain callable object (not an ``nn.Module``) that adds one to its input."""

    def __call__(self, x):
        incremented = x + 1
        return incremented


class TestFunctionalIterDataPipe(TestCase):
    def _serialization_test_helper(self, datapipe, use_dill):
        """Round-trip ``datapipe`` through pickle (or dill) and compare outputs.

        Args:
            datapipe: the DataPipe to serialize and deserialize.
            use_dill: when truthy, use ``dill`` instead of ``pickle``.

        Raises:
            AssertionError: if the deserialized copy yields different elements;
                the failing DataPipe is printed first to ease debugging.
        """
        serializer = dill if use_dill else pickle
        deserialized_dp = serializer.loads(serializer.dumps(datapipe))
        try:
            self.assertEqual(list(datapipe), list(deserialized_dp))
        except AssertionError:
            print(f"{datapipe} is failing.")
            raise

    def _serialization_test_for_single_dp(self, dp, use_dill=False):
        """Check serialization of ``dp`` before, during and after iteration."""
        # Stage 1: nothing has been read yet
        self._serialization_test_helper(dp, use_dill)

        # Stage 2: one element consumed from a live iterator
        dp_iter = iter(dp)
        _ = next(dp_iter)
        self._serialization_test_helper(dp, use_dill)

        # Stage 3: a fresh iterator drained completely
        dp_iter = iter(dp)
        _ = list(dp_iter)
        self._serialization_test_helper(dp, use_dill)

    def _serialization_test_for_dp_with_children(self, dp1, dp2, use_dill=False):
        """Check serialization of two sibling child DataPipes (e.g. from ``fork``
        or ``demux``) at several points of their joint iteration."""

        def _check_both():
            self._serialization_test_helper(dp1, use_dill)
            self._serialization_test_helper(dp2, use_dill)

        # 1. Neither child has been read yet
        _check_both()

        # 2. One element consumed from each child
        it1, it2 = iter(dp1), iter(dp2)
        _, _ = next(it1), next(it2)
        # Record (instead of showing) the `fork`/`demux` warning that some child
        # DataPipes are not exhausted
        with warnings.catch_warnings(record=True):
            _check_both()

        # 2.5. First child fully read while the second is still partially read
        it1 = iter(dp1)
        _ = list(it1)
        with warnings.catch_warnings(record=True):
            _check_both()

        # 3. Second child fully read as well
        it2 = iter(dp2)
        _ = list(it2)
        _check_both()

    def test_serializable(self):
        """DataPipes built from picklable callables must survive pickling.

        Each entry is ``(DataPipe class, custom input or None, args, kwargs)``;
        ``None`` means the input is ``IterableWrapper(range(10))``.
        """
        picklable_datapipes: List = [
            (dp.iter.Batcher, None, (3, True), {}),
            (dp.iter.Collator, None, (_fake_fn,), {}),
            (dp.iter.Concater, None, (dp.iter.IterableWrapper(range(5)),), {}),
            (dp.iter.Demultiplexer, None, (2, _simple_filter_fn), {}),
            (dp.iter.FileLister, ".", (), {}),
            (dp.iter.FileOpener, None, (), {}),
            (dp.iter.Filter, None, (_fake_filter_fn,), {}),
            (dp.iter.Filter, None, (partial(_fake_filter_fn_constant, 5),), {}),
            (dp.iter.Forker, None, (2,), {}),
            (dp.iter.Forker, None, (2,), {"copy": "shallow"}),
            (dp.iter.Grouper, None, (_fake_filter_fn,), {"group_size": 2}),
            (dp.iter.IterableWrapper, range(10), (), {}),
            (dp.iter.Mapper, None, (_fake_fn,), {}),
            (dp.iter.Mapper, None, (partial(_fake_add, 1),), {}),
            (dp.iter.Multiplexer, None, (dp.iter.IterableWrapper(range(10)),), {}),
            (dp.iter.Sampler, None, (), {}),
            (dp.iter.Shuffler, dp.iter.IterableWrapper([0] * 10), (), {}),
            (dp.iter.StreamReader, None, (), {}),
            (dp.iter.UnBatcher, None, (0,), {}),
            (dp.iter.Zipper, None, (dp.iter.IterableWrapper(range(10)),), {}),
        ]
        # Only checked for being picklable and loadable; outputs are not compared
        dp_skip_comparison = {dp.iter.FileOpener, dp.iter.StreamReader}
        # Return two child DataPipes, and both children are compared
        dp_compare_children = {dp.iter.Demultiplexer, dp.iter.Forker}

        for dp_cls, dp_input, dp_args, dp_kwargs in picklable_datapipes:
            source = (
                dp.iter.IterableWrapper(range(10)) if dp_input is None else dp_input
            )
            if dp_cls in dp_skip_comparison:
                datapipe = dp_cls(source, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
                serialized_dp = pickle.dumps(datapipe)
                _ = pickle.loads(serialized_dp)
            elif dp_cls in dp_compare_children:
                dp1, dp2 = dp_cls(source, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
                self._serialization_test_for_dp_with_children(dp1, dp2)
            else:
                datapipe = dp_cls(source, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
                self._serialization_test_for_single_dp(datapipe)

    @skipIfTorchDynamo("Dict with function as keys")
    def test_serializable_with_dill(self):
        """Only for DataPipes that take in a function as argument"""
        input_dp = dp.iter.IterableWrapper(range(10))

        def _cases(
            identity_fn, parity_fn, threshold_fn
        ) -> List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]]:
            # (DataPipe class, positional args after the input DataPipe, kwargs)
            return [
                (dp.iter.Collator, (identity_fn,), {}),
                (dp.iter.Demultiplexer, (2, parity_fn), {}),
                (dp.iter.Filter, (threshold_fn,), {}),
                (dp.iter.Grouper, (threshold_fn,), {}),
                (dp.iter.Mapper, (identity_fn,), {}),
            ]

        def _local_fns():
            # Defined in a nested scope on purpose: these are "local functions"
            def _fn1(x):
                return x

            def _fn2(x):
                return x % 2

            def _fn3(x):
                return x >= 5

            return _fn1, _fn2, _fn3

        datapipes_with_lambda_fn = _cases(lambda_fn1, lambda_fn2, lambda_fn3)
        datapipes_with_local_fn = _cases(*_local_fns())

        dp_compare_children = {dp.iter.Demultiplexer}

        if not HAS_DILL:
            # Without dill, construction warns and plain pickle must fail
            msgs = (
                r"^Lambda function is not supported by pickle",
                r"^Local function is not supported by pickle",
            )
            for dps, msg in zip(
                (datapipes_with_lambda_fn, datapipes_with_local_fn), msgs
            ):
                for dpipe, dp_args, dp_kwargs in dps:
                    with self.assertWarnsRegex(UserWarning, msg):
                        datapipe = dpipe(input_dp, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
                    with self.assertRaises((pickle.PicklingError, AttributeError)):
                        pickle.dumps(datapipe)
        else:
            # With dill, every DataPipe must round-trip
            for dpipe, dp_args, dp_kwargs in (
                datapipes_with_lambda_fn + datapipes_with_local_fn
            ):
                if dpipe in dp_compare_children:
                    dp1, dp2 = dpipe(input_dp, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
                    self._serialization_test_for_dp_with_children(
                        dp1, dp2, use_dill=True
                    )
                else:
                    datapipe = dpipe(input_dp, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
                    self._serialization_test_for_single_dp(datapipe, use_dill=True)

    def test_docstring(self):
        """
        Ensure functional form of IterDataPipe has the correct docstring from
        the class form.

        Regression test for https://github.com/pytorch/data/issues/792.
        """
        input_dp = dp.iter.IterableWrapper(range(10))

        for dp_funcname in [
            "batch",
            "collate",
            "concat",
            "demux",
            "filter",
            "fork",
            "map",
            "mux",
            "read_from_stream",
            # "sampler",
            "shuffle",
            "unbatch",
            "zip",
        ]:
            functional_form = getattr(input_dp, dp_funcname)
            if sys.version_info >= (3, 9):
                docstring = pydoc.render_doc(thing=functional_form, forceload=True)
            else:
                # pydoc works differently on Python 3.8, see
                # https://docs.python.org/3/whatsnew/3.9.html#pydoc
                docstring = functional_form.__doc__

            self.assertIn(f"(functional name: ``{dp_funcname}``)", docstring)
            self.assertIn("Args:", docstring)
            self.assertTrue(
                "Example:" in docstring or "Examples:" in docstring,
                msg=f"`{dp_funcname}` docstring has no Example(s) section",
            )

    def test_iterable_wrapper_datapipe(self):
        """IterableWrapper: value order, deep vs. shallow copy, reset and ``__len__``."""
        source = list(range(10))
        wrapped = dp.iter.IterableWrapper(source)

        # Functional Test: iteration yields the source values unchanged and in order
        self.assertEqual(source, list(wrapped))

        # Functional Test: with the default deep copy, the copy is taken when the
        # first element is read, so a later mutation of the source is not seen
        iterator = iter(wrapped)
        first = next(iterator)
        self.assertEqual(0, first)
        source.append(50)
        self.assertEqual(list(range(1, 10)), list(iterator))

        # Functional Test: with deepcopy=False the wrapper observes later mutations
        shallow_source = [1, 2, 3]
        shallow_wrapped = dp.iter.IterableWrapper(shallow_source, deepcopy=False)
        shallow_source.append(10)
        self.assertEqual([1, 2, 3, 10], list(shallow_wrapped))

        # Reset Test: a new iterator after a partial read starts from the beginning
        source = list(range(10))
        wrapped = dp.iter.IterableWrapper(source)
        num_before_reset = 5
        before, after = reset_after_n_next_calls(wrapped, num_before_reset)
        self.assertEqual(source[:num_before_reset], before)
        self.assertEqual(source, after)

        # __len__ Test: the length comes from the wrapped sequence
        self.assertEqual(len(source), len(wrapped))

    def test_concat_iterdatapipe(self):
        """Concater: argument validation, ordering, reset and ``__len__`` propagation."""
        first = dp.iter.IterableWrapper(range(10))
        second = dp.iter.IterableWrapper(range(5))
        expected = [*range(10), *range(5)]

        # Functional Test: at least one input DataPipe is required
        with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"):
            dp.iter.Concater()

        # Functional Test: every input must be an IterDataPipe
        with self.assertRaisesRegex(
            TypeError, r"Expected all inputs to be `IterDataPipe`"
        ):
            dp.iter.Concater(first, ())  # type: ignore[arg-type]

        # Functional Test: inputs are chained back to back in argument order
        chained = first.concat(second)
        self.assertEqual(len(chained), 15)
        self.assertEqual(list(chained), expected)

        # Reset Test: a new iterator after 5 reads starts again from the beginning
        before, after = reset_after_n_next_calls(chained, 5)
        self.assertEqual(list(range(5)), before)
        self.assertEqual(expected, after)

        # __len__ Test: invalid once any input lacks a length; iteration still works
        no_len_source = IDP_NoLen(range(5))
        chained = first.concat(no_len_source)
        with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
            len(chained)

        self.assertEqual(list(chained), expected)

    def test_fork_iterdatapipe(self):
        """
        Exercise ``IterDataPipe.fork``: argument validation, lock-step and
        one-child-at-a-time consumption, ``buffer_size`` limits, ``copy`` modes,
        reset/invalidation of child iterators, ``__len__`` and graph traversal.
        """
        input_dp = dp.iter.IterableWrapper(range(10))

        # Functional Test: num_instances=0 is rejected
        with self.assertRaises(ValueError):
            input_dp.fork(num_instances=0)

        # Functional Test: a single-instance fork compares equal to the input DataPipe
        dp0 = input_dp.fork(num_instances=1, buffer_size=0)
        self.assertEqual(dp0, input_dp)

        # Functional Test: making sure all child DataPipe shares the same reference
        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
        self.assertTrue(all(n1 is n2 and n1 is n3 for n1, n2, n3 in zip(dp1, dp2, dp3)))

        # Functional Test: one child DataPipe yields all value at a time
        output1, output2, output3 = list(dp1), list(dp2), list(dp3)
        self.assertEqual(list(range(10)), output1)
        self.assertEqual(list(range(10)), output2)
        self.assertEqual(list(range(10)), output3)

        # Functional Test: two child DataPipes yield value together
        dp1, dp2 = input_dp.fork(num_instances=2)
        output = []
        for n1, n2 in zip(dp1, dp2):
            output.append((n1, n2))
        self.assertEqual([(i, i) for i in range(10)], output)

        # Functional Test: one child DataPipe yields all value first, but buffer_size = 4 is too small
        dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=4)
        it1 = iter(dp1)
        for _ in range(4):
            next(it1)
        # With dp2 still unread, the 5th read from dp1 no longer fits in the buffer
        with self.assertRaises(BufferError):
            next(it1)
        with self.assertRaises(BufferError):
            list(dp2)

        # Functional Test: buffer_size = 5 is still too small for one child to read all 10 values alone
        dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=5)
        with self.assertRaises(BufferError):
            list(dp2)

        # Functional Test: one child DataPipe yields all value first with unlimited buffer
        # (buffer_size=-1 is accepted but emits exactly one warning)
        with warnings.catch_warnings(record=True) as wa:
            dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=-1)
            self.assertEqual(len(wa), 1)
            self.assertRegex(str(wa[0].message), r"Unlimited buffer size is set")
        l1, l2 = list(dp1), list(dp2)
        for d1, d2 in zip(l1, l2):
            self.assertEqual(d1, d2)

        # Functional Test: two child DataPipes yield value together with buffer size 1
        dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=1)
        output = []
        for n1, n2 in zip(dp1, dp2):
            output.append((n1, n2))
        self.assertEqual([(i, i) for i in range(10)], output)

        # Functional Test: two child DataPipes yield shallow copies with copy equals shallow
        # (each child gets its own top-level list object with equal contents)
        dp1, dp2 = input_dp.map(_to_list).fork(num_instances=2, copy="shallow")
        for n1, n2 in zip(dp1, dp2):
            self.assertIsNot(n1, n2)
            self.assertEqual(n1, n2)

        # Functional Test: two child DataPipes yield deep copies with copy equals deep
        # (elements are nested lists; even the inner list objects must differ)
        dp1, dp2 = (
            input_dp.map(_to_list).map(_to_list).fork(num_instances=2, copy="deep")
        )
        for n1, n2 in zip(dp1, dp2):
            self.assertIsNot(n1[0], n2[0])
            self.assertEqual(n1, n2)

        # Functional Test: fork DataPipe raises error for unknown copy method
        with self.assertRaises(ValueError):
            input_dp.fork(num_instances=2, copy="unknown")

        # Functional Test: make sure logic related to slowest_ptr is working properly
        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
        output1, output2, output3 = [], [], []
        for i, (n1, n2) in enumerate(zip(dp1, dp2)):
            output1.append(n1)
            output2.append(n2)
            if i == 4:  # yield all of dp3 when halfway through dp1, dp2
                output3 = list(dp3)
                break
        self.assertEqual(list(range(5)), output1)
        self.assertEqual(list(range(5)), output2)
        self.assertEqual(list(range(10)), output3)

        # Reset Test: DataPipe resets when a new iterator is created, even if this datapipe hasn't been read
        dp1, dp2 = input_dp.fork(num_instances=2)
        _ = iter(dp1)
        output2 = []
        # Creating a new dp1 iterator mid-way invalidates the running dp2 iterator
        with self.assertRaisesRegex(RuntimeError, r"iterator has been invalidated"):
            for i, n2 in enumerate(dp2):
                output2.append(n2)
                if i == 4:
                    with warnings.catch_warnings(record=True) as wa:
                        _ = iter(dp1)  # This will reset all child DataPipes
                        self.assertEqual(len(wa), 1)
                        self.assertRegex(
                            str(wa[0].message), r"child DataPipes are not exhausted"
                        )
        self.assertEqual(list(range(5)), output2)

        # Reset Test: DataPipe resets when some of it has been read
        dp1, dp2 = input_dp.fork(num_instances=2)
        output1, output2 = [], []
        for i, (n1, n2) in enumerate(zip(dp1, dp2)):
            output1.append(n1)
            output2.append(n2)
            if i == 4:
                with warnings.catch_warnings(record=True) as wa:
                    _ = iter(dp1)  # Reset all child DataPipes
                    self.assertEqual(len(wa), 1)
                    self.assertRegex(
                        str(wa[0].message), r"Some child DataPipes are not exhausted"
                    )
                break
        # Reading again starts from the beginning; exactly one warning is recorded here
        with warnings.catch_warnings(record=True) as wa:
            for i, (n1, n2) in enumerate(zip(dp1, dp2)):
                output1.append(n1)
                output2.append(n2)
            self.assertEqual(len(wa), 1)
            self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted")
        self.assertEqual(list(range(5)) + list(range(10)), output1)
        self.assertEqual(list(range(5)) + list(range(10)), output2)

        # Reset Test: DataPipe reset, even when some other child DataPipes are not read
        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
        output1, output2 = list(dp1), list(dp2)
        self.assertEqual(list(range(10)), output1)
        self.assertEqual(list(range(10)), output2)
        with warnings.catch_warnings(record=True) as wa:
            self.assertEqual(
                list(range(10)), list(dp1)
            )  # Resets even though dp3 has not been read
            self.assertEqual(len(wa), 1)
            self.assertRegex(
                str(wa[0].message), r"Some child DataPipes are not exhausted"
            )
        output3 = []
        for i, n3 in enumerate(dp3):
            output3.append(n3)
            if i == 4:
                with warnings.catch_warnings(record=True) as wa:
                    output1 = list(dp1)  # Resets even though dp3 is only partially read
                    self.assertEqual(len(wa), 1)
                    self.assertRegex(
                        str(wa[0].message), r"Some child DataPipes are not exhausted"
                    )
                self.assertEqual(list(range(5)), output3)
                self.assertEqual(list(range(10)), output1)
                break
        self.assertEqual(
            list(range(10)), list(dp3)
        )  # dp3 has to read from the start again

        # __len__ Test: Each DataPipe inherits the source datapipe's length
        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
        self.assertEqual(len(input_dp), len(dp1))
        self.assertEqual(len(input_dp), len(dp2))
        self.assertEqual(len(input_dp), len(dp3))

        # Pickle Test: graph traversal works before iteration and after all children are consumed
        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
        traverse_dps(dp1)  # This should not raise any error
        for _ in zip(dp1, dp2, dp3):
            pass
        traverse_dps(dp2)  # This should not raise any error either

    def test_mux_iterdatapipe(self):
        """Multiplexer: round-robin order, early stop on the shortest input, ``__len__``."""

        def _assert_muxed(expected, muxed):
            # Length first, then contents, for every functional case
            self.assertEqual(len(expected), len(muxed))
            self.assertEqual(list(expected), list(muxed))

        # Functional Test: equal-length inputs are interleaved one element at a time
        first, second, third = (
            dp.iter.IterableWrapper(range(start, start + 4)) for start in (0, 4, 8)
        )
        _assert_muxed(
            [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11], first.mux(second, third)
        )

        # Functional Test: uneven inputs -- only complete rounds are emitted, so the
        # output ends once the shortest input runs out
        first = dp.iter.IterableWrapper([1, 2, 3, 4])
        second = dp.iter.IterableWrapper([10])
        third = dp.iter.IterableWrapper([100, 200, 300])
        _assert_muxed([1, 10, 100], first.mux(second, third))

        # Functional Test: muxing with an empty input yields nothing
        first = dp.iter.IterableWrapper([0, 1, 2, 3])
        empty = dp.iter.IterableWrapper([])
        _assert_muxed(empty, first.mux(empty))

        # __len__ Test: TypeError when any input does not implement __len__
        muxed = dp.iter.IterableWrapper(range(10)).mux(IDP_NoLen(range(10)))
        with self.assertRaises(TypeError):
            len(muxed)

    def test_demux_iterdatapipe(self):
        """
        Exercise ``IterDataPipe.demux``: argument validation, routing by
        ``classifier_fn``, ``buffer_size`` limits, out-of-range classifications,
        reset/invalidation of child iterators, ``drop_none``, ``__len__`` and
        graph traversal.
        """
        input_dp = dp.iter.IterableWrapper(range(10))

        # Functional Test: num_instances=0 is rejected
        with self.assertRaises(ValueError):
            input_dp.demux(num_instances=0, classifier_fn=lambda x: 0)

        # Functional Test: split into 2 DataPipes and output them one at a time
        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
        output1, output2 = list(dp1), list(dp2)
        self.assertEqual(list(range(0, 10, 2)), output1)
        self.assertEqual(list(range(1, 10, 2)), output2)

        # Functional Test: split into 2 DataPipes and output them together
        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
        output = []
        for n1, n2 in zip(dp1, dp2):
            output.append((n1, n2))
        self.assertEqual([(i, i + 1) for i in range(0, 10, 2)], output)

        # Functional Test: values of the same classification are lumped together, and buffer_size = 4 is too small
        dp1, dp2 = input_dp.demux(
            num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=4
        )
        it1 = iter(dp1)
        with self.assertRaises(BufferError):
            # Raises because the first 5 elements all belong to a different child (dp2)
            next(it1)
        with self.assertRaises(BufferError):
            list(dp2)

        # Functional Test: values of the same classification are lumped together, and buffer_size = 5 is just enough
        dp1, dp2 = input_dp.demux(
            num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=5
        )
        output1, output2 = list(dp1), list(dp2)
        self.assertEqual(list(range(5, 10)), output1)
        self.assertEqual(list(range(0, 5)), output2)

        # Functional Test: values of the same classification are lumped together, and unlimited buffer
        with warnings.catch_warnings(record=True) as wa:
            dp1, dp2 = input_dp.demux(
                num_instances=2,
                classifier_fn=lambda x: 0 if x >= 5 else 1,
                buffer_size=-1,
            )
            # Without dill, the lambda classifier adds a pickling warning
            exp_l = 1 if HAS_DILL else 2
            self.assertEqual(len(wa), exp_l)
            self.assertRegex(str(wa[-1].message), r"Unlimited buffer size is set")
        output1, output2 = list(dp1), list(dp2)
        self.assertEqual(list(range(5, 10)), output1)
        self.assertEqual(list(range(0, 5)), output2)

        # Functional Test: classifier returns a value outside of [0, num_instance - 1]
        dp0 = input_dp.demux(num_instances=1, classifier_fn=lambda x: x % 2)
        it = iter(dp0[0])
        # 0 is classified as 0, which is valid for a single instance ...
        self.assertEqual(0, next(it))
        # ... but 1 is classified as 1, which is out of range
        with self.assertRaises(ValueError):
            next(it)

        # Reset Test: DataPipe resets when a new iterator is created, even if this datapipe hasn't been read
        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
        _ = iter(dp1)
        output2 = []
        # Creating a new dp1 iterator mid-way invalidates the running dp2 iterator
        with self.assertRaisesRegex(RuntimeError, r"iterator has been invalidated"):
            for i, n2 in enumerate(dp2):
                output2.append(n2)
                if i == 4:
                    with warnings.catch_warnings(record=True) as wa:
                        _ = iter(dp1)  # This will reset all child DataPipes
                        self.assertEqual(len(wa), 1)
                        self.assertRegex(
                            str(wa[0].message), r"child DataPipes are not exhausted"
                        )
        self.assertEqual(list(range(1, 10, 2)), output2)

        # Reset Test: DataPipe resets when some of it has been read
        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
        output1, output2 = [], []
        for n1, n2 in zip(dp1, dp2):
            output1.append(n1)
            output2.append(n2)
            if n1 == 4:
                break
        with warnings.catch_warnings(record=True) as wa:
            _ = iter(dp1)  # Reset all child DataPipes
            self.assertEqual(len(wa), 1)
            self.assertRegex(
                str(wa[0].message), r"Some child DataPipes are not exhausted"
            )
            for n1, n2 in zip(dp1, dp2):
                output1.append(n1)
                output2.append(n2)
            self.assertEqual([0, 2, 4] + list(range(0, 10, 2)), output1)
            self.assertEqual([1, 3, 5] + list(range(1, 10, 2)), output2)
            self.assertEqual(len(wa), 1)
            self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted")

        # Reset Test: DataPipe reset, even when not all child DataPipes are exhausted
        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
        output1 = list(dp1)
        self.assertEqual(list(range(0, 10, 2)), output1)
        with warnings.catch_warnings(record=True) as wa:
            self.assertEqual(
                list(range(0, 10, 2)), list(dp1)
            )  # Reset even when dp2 is not read
            self.assertEqual(len(wa), 1)
            self.assertRegex(
                str(wa[0].message), r"Some child DataPipes are not exhausted"
            )
        output2 = []
        for i, n2 in enumerate(dp2):
            output2.append(n2)
            if i == 1:
                self.assertEqual(list(range(1, 5, 2)), output2)
                with warnings.catch_warnings(record=True) as wa:
                    self.assertEqual(
                        list(range(0, 10, 2)), list(dp1)
                    )  # Can reset even when dp2 is partially read
                    self.assertEqual(len(wa), 1)
                    self.assertRegex(
                        str(wa[0].message), r"Some child DataPipes are not exhausted"
                    )
                break
        output2 = list(dp2)  # output2 has to read from beginning again
        self.assertEqual(list(range(1, 10, 2)), output2)

        # Functional Test: drop_none = True (multiples of 5 are classified as None and dropped)
        dp1, dp2 = input_dp.demux(
            num_instances=2,
            classifier_fn=lambda x: x % 2 if x % 5 != 0 else None,
            drop_none=True,
        )
        self.assertEqual([2, 4, 6, 8], list(dp1))
        self.assertEqual([1, 3, 7, 9], list(dp2))

        # Functional Test: drop_none = False (a None classification raises)
        dp1, dp2 = input_dp.demux(
            num_instances=2,
            classifier_fn=lambda x: x % 2 if x % 5 != 0 else None,
            drop_none=False,
        )
        it1 = iter(dp1)
        with self.assertRaises(ValueError):
            next(it1)

        # __len__ Test: __len__ not implemented
        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
        with self.assertRaises(TypeError):
            len(
                dp1
            )  # It is not implemented as we do not know length for each child in advance
        with self.assertRaises(TypeError):
            len(dp2)

        # Pickle Test: graph traversal works before iteration and after all children are consumed
        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=odd_or_even)
        traverse_dps(dp1)  # This should not raise any error
        for _ in zip(dp1, dp2):
            pass
        traverse_dps(dp2)  # This should not raise any error either

    def test_map_iterdatapipe(self):
        """
        Mapper: applies ``fn`` to every element (including ``functools.partial``
        callables), works over a source without ``__len__``, propagates
        ``__len__`` from its source, and resets properly.
        """
        target_length = 10
        input_dp = dp.iter.IterableWrapper(range(target_length))

        def fn(item, dtype=torch.float, *, sum=False):
            data = torch.tensor(item, dtype=dtype)
            return data if not sum else data.sum()

        # Functional Test: apply to each element correctly
        map_dp = input_dp.map(fn)
        self.assertEqual(target_length, len(map_dp))
        for x, y in zip(map_dp, range(target_length)):
            self.assertEqual(x, torch.tensor(y, dtype=torch.float))

        # Functional Test: works with partial function
        map_dp = input_dp.map(partial(fn, dtype=torch.int, sum=True))
        for x, y in zip(map_dp, range(target_length)):
            self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())

        # __len__ Test: inherits length from source DataPipe
        self.assertEqual(target_length, len(map_dp))

        # Functional Test: mapping also works over a source without __len__.
        # Map with ``fn`` so that the float-tensor comparison below actually
        # verifies the function was applied. An identity map would pass the
        # source items through unchanged.
        input_dp_nl = IDP_NoLen(range(target_length))
        map_dp_nl = input_dp_nl.map(fn)
        for x, y in zip(map_dp_nl, range(target_length)):
            self.assertEqual(x, torch.tensor(y, dtype=torch.float))

        # __len__ Test: inherits length from source DataPipe - raises error when invalid
        with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
            len(map_dp_nl)

        # Reset Test: DataPipe resets properly
        n_elements_before_reset = 5
        res_before_reset, res_after_reset = reset_after_n_next_calls(
            map_dp, n_elements_before_reset
        )
        self.assertEqual(list(range(n_elements_before_reset)), res_before_reset)
        self.assertEqual(list(range(10)), res_after_reset)

    @suppress_warnings  # Suppress warning for lambda fn
    def test_map_tuple_list_with_col_iterdatapipe(self):
        """
        ``map`` with ``input_col``/``output_col`` over list and tuple elements.

        Each case compares ``datapipe.map(fn, input_col, output_col)`` against a
        reference ``datapipe.map(ref_fn)``. When ``ref_fn`` is None, the case
        instead checks that ``error`` is raised.
        """

        # Fixtures named fn_<inputs><outputs>: "1" = one value, "n" = several
        def fn_11(d):
            return -d

        def fn_1n(d):
            return -d, d

        def fn_n1(d0, d1):
            return d0 + d1

        def fn_nn(d0, d1):
            return -d0, -d1, d0 + d1

        def fn_n1_def(d0, d1=1):
            return d0 + d1

        def fn_n1_kwargs(d0, d1, **kwargs):
            return d0 + d1

        def fn_n1_pos(d0, d1, *args):
            return d0 + d1

        def fn_n1_sep_pos(d0, *args, d1):
            return d0 + d1

        def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
            return d0 + d1

        # Partial application fixes d1 (fn_n1) or the keyword-only d2 (fn_cmplx)
        p_fn_n1 = partial(fn_n1, d1=1)
        p_fn_cmplx = partial(fn_cmplx, d2=2)
        # Partial bound to a very large argument: the dict values hold about 50
        # million list elements in total. NOTE(review): presumably a regression
        # guard against copying or inspecting partial arguments during column
        # validation; this allocation is sizeable, so confirm the size is needed.
        p_fn_cmplx_large_arg = partial(
            fn_cmplx, d2={i: list(range(i)) for i in range(10_000)}
        )

        def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
            # Run every case against both list rows and tuple rows
            for constr in (list, tuple):
                datapipe = dp.iter.IterableWrapper(
                    [constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))]
                )
                if ref_fn is None:
                    # Either construction or iteration is allowed to raise
                    with self.assertRaises(error):
                        res_dp = datapipe.map(fn, input_col, output_col)
                        list(res_dp)
                else:
                    res_dp = datapipe.map(fn, input_col, output_col)
                    ref_dp = datapipe.map(ref_fn)
                    self.assertEqual(list(res_dp), list(ref_dp))
                    # Reset
                    self.assertEqual(list(res_dp), list(ref_dp))

        # In every row row[1] == row[0] + 1. The next calls compute row[0] + 1
        # into column 1, so the rows come back unchanged.
        _helper(lambda data: data, fn_n1_def, 0, 1)
        _helper(
            lambda data: (data[0], data[1], data[0] + data[1]), fn_n1_def, [0, 1], 2
        )
        _helper(lambda data: data, p_fn_n1, 0, 1)
        _helper(lambda data: data, p_fn_cmplx, 0, 1)
        _helper(lambda data: data, p_fn_cmplx_large_arg, 0, 1)
        _helper(
            lambda data: (data[0], data[1], data[0] + data[1]), p_fn_cmplx, [0, 1], 2
        )
        # Extra input columns are consumed by *args; the result replaces column 0
        # and the other input columns are removed
        _helper(lambda data: (data[0] + data[1],), fn_n1_pos, [0, 1, 2])

        # Replacing with one input column and default output column
        _helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
        _helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1)
        # The index of input column is out of range
        _helper(None, fn_1n, 3, error=IndexError)
        # Unmatched input columns with fn arguments
        _helper(None, fn_n1, 1, error=ValueError)
        _helper(None, fn_n1, [0, 1, 2], error=ValueError)
        _helper(None, operator.add, 0, error=ValueError)
        _helper(None, operator.add, [0, 1, 2], error=ValueError)
        _helper(None, fn_cmplx, 0, 1, ValueError)
        _helper(None, fn_n1_pos, 1, error=ValueError)
        _helper(None, fn_n1_def, [0, 1, 2], 1, error=ValueError)
        _helper(None, p_fn_n1, [0, 1], error=ValueError)
        _helper(None, fn_1n, [1, 2], error=ValueError)
        # NOTE(review): disabled case; the reason is not recorded here.
        # _helper(None, p_fn_cmplx, [0, 1, 2], error=ValueError)
        _helper(None, fn_n1_sep_pos, [0, 1, 2], error=ValueError)
        # Fn accepts **kwargs (fn_n1_kwargs) or has a required keyword-only
        # argument (fn_cmplx's d2)
        _helper(None, fn_n1_kwargs, 1, error=ValueError)
        _helper(None, fn_cmplx, [0, 1], 2, ValueError)

        # Replacing with multiple input columns and default output column (the left-most input column)
        _helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
        _helper(
            lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])),
            fn_nn,
            [2, 1],
        )

        # output_col can only be specified when input_col is not None
        _helper(None, fn_n1, None, 1, error=ValueError)
        # output_col can only be single-element list or tuple
        _helper(None, fn_n1, None, [0, 1], error=ValueError)
        # Single-element list as output_col
        _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0])
        # Replacing with one input column and single specified output column
        _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0)
        _helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2)
        # The index of output column is out of range
        _helper(None, fn_1n, 1, 3, error=IndexError)
        _helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1)
        _helper(
            lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]),
            fn_nn,
            [1, 2],
            0,
        )

        # Appending the output at the end
        _helper(lambda data: (*data, -data[1]), fn_11, 1, -1)
        _helper(lambda data: (*data, (-data[1], data[1])), fn_1n, 1, -1)
        _helper(lambda data: (*data, data[0] + data[2]), fn_n1, [0, 2], -1)
        _helper(
            lambda data: (*data, (-data[1], -data[2], data[1] + data[2])),
            fn_nn,
            [1, 2],
            -1,
        )

        # Handling built-in functions (e.g. `dict`, `iter`, `int`, `str`) whose signatures cannot be inspected
        _helper(lambda data: (str(data[0]), data[1], data[2]), str, 0)
        _helper(lambda data: (data[0], data[1], int(data[2])), int, 2)

        # Handle nn.Module and Callable (without __name__ implemented)
        _helper(lambda data: (data[0] + 1, data[1], data[2]), Add1Module(), 0)
        _helper(lambda data: (data[0] + 1, data[1], data[2]), Add1Callable(), 0)

    @suppress_warnings  # Suppress warning for lambda fn
    @skipIfTorchDynamo()
    def test_map_dict_with_col_iterdatapipe(self):
        """
        ``map`` with ``input_col``/``output_col`` over dict elements keyed by
        "x", "y" and "z".

        Each case compares ``datapipe.map(fn, input_col, output_col)`` against a
        reference ``datapipe.map(ref_fn)``. When ``ref_fn`` is None, the case
        instead checks that ``error`` is raised.
        """

        # Fixtures named fn_<inputs><outputs>: "1" = one value, "n" = several
        def fn_11(d):
            return -d

        def fn_1n(d):
            return -d, d

        def fn_n1(d0, d1):
            return d0 + d1

        def fn_nn(d0, d1):
            return -d0, -d1, d0 + d1

        def fn_n1_def(d0, d1=1):
            return d0 + d1

        p_fn_n1 = partial(fn_n1, d1=1)

        def fn_n1_pos(d0, d1, *args):
            return d0 + d1

        def fn_n1_kwargs(d0, d1, **kwargs):
            return d0 + d1

        def fn_kwonly(*, d0, d1):
            return d0 + d1

        def fn_has_nondefault_kwonly(d0, *, d1):
            return d0 + d1

        def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
            return d0 + d1

        p_fn_cmplx = partial(fn_cmplx, d2=2)
        # Partial bound to a very large argument: the dict values hold about 50
        # million list elements in total. NOTE(review): presumably a regression
        # guard against copying or inspecting partial arguments during column
        # validation; confirm the size is still needed.
        p_fn_cmplx_large_arg = partial(
            fn_cmplx, d2={i: list(range(i)) for i in range(10_000)}
        )

        # Prevent modification in-place to support resetting
        def _dict_update(data, newdata, remove_idx=None):
            # Shallow-copy, apply updates, then drop the keys listed in remove_idx
            _data = dict(data)
            _data.update(newdata)
            if remove_idx:
                for idx in remove_idx:
                    del _data[idx]
            return _data

        def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
            datapipe = dp.iter.IterableWrapper(
                [
                    {"x": 0, "y": 1, "z": 2},
                    {"x": 3, "y": 4, "z": 5},
                    {"x": 6, "y": 7, "z": 8},
                ]
            )
            if ref_fn is None:
                # Either construction or iteration is allowed to raise
                with self.assertRaises(error):
                    res_dp = datapipe.map(fn, input_col, output_col)
                    list(res_dp)
            else:
                res_dp = datapipe.map(fn, input_col, output_col)
                ref_dp = datapipe.map(ref_fn)
                self.assertEqual(list(res_dp), list(ref_dp))
                # Reset
                self.assertEqual(list(res_dp), list(ref_dp))

        # In every row "y" == "x" + 1. The next calls compute "x" + 1 into "y",
        # so the rows come back unchanged.
        _helper(lambda data: data, fn_n1_def, "x", "y")
        _helper(lambda data: data, p_fn_n1, "x", "y")
        _helper(lambda data: data, p_fn_cmplx, "x", "y")
        _helper(lambda data: data, p_fn_cmplx_large_arg, "x", "y")
        _helper(
            lambda data: _dict_update(data, {"z": data["x"] + data["y"]}),
            p_fn_cmplx,
            ["x", "y", "z"],
            "z",
        )

        _helper(
            lambda data: _dict_update(data, {"z": data["x"] + data["y"]}),
            fn_n1_def,
            ["x", "y"],
            "z",
        )

        # Too few input columns for fn's required positional parameters
        _helper(None, fn_n1_pos, "x", error=ValueError)
        _helper(None, fn_n1_kwargs, "x", error=ValueError)
        # Keyword-only parameters without defaults
        _helper(None, fn_kwonly, ["x", "y"], error=ValueError)
        _helper(None, fn_has_nondefault_kwonly, ["x", "y"], error=ValueError)
        _helper(None, fn_cmplx, ["x", "y"], error=ValueError)

        # Replacing with one input column and default output column
        _helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
        _helper(
            lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y"
        )
        # The key of input column is not in dict
        _helper(None, fn_1n, "a", error=KeyError)
        # Unmatched input columns with fn arguments
        _helper(None, fn_n1, "y", error=ValueError)
        _helper(None, fn_1n, ["x", "y"], error=ValueError)
        _helper(None, fn_n1_def, ["x", "y", "z"], error=ValueError)
        _helper(None, p_fn_n1, ["x", "y"], error=ValueError)
        _helper(None, fn_n1_kwargs, ["x", "y", "z"], error=ValueError)
        # Replacing with multiple input columns and default output column (the left-most input column)
        # (the remaining input keys are removed from the dict)
        _helper(
            lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]),
            fn_n1,
            ["z", "x"],
        )
        _helper(
            lambda data: _dict_update(
                data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"]
            ),
            fn_nn,
            ["z", "y"],
        )

        # output_col can only be specified when input_col is not None
        _helper(None, fn_n1, None, "x", error=ValueError)
        # output_col can only be single-element list or tuple
        _helper(None, fn_n1, None, ["x", "y"], error=ValueError)
        # Single-element list as output_col
        _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"])
        # Replacing with one input column and single specified output column
        _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", "x")
        _helper(
            lambda data: _dict_update(data, {"z": (-data["y"], data["y"])}),
            fn_1n,
            "y",
            "z",
        )
        _helper(
            lambda data: _dict_update(data, {"y": data["x"] + data["z"]}),
            fn_n1,
            ["x", "z"],
            "y",
        )
        _helper(
            lambda data: _dict_update(
                data, {"x": (-data["y"], -data["z"], data["y"] + data["z"])}
            ),
            fn_nn,
            ["y", "z"],
            "x",
        )

        # Adding new key to dict for the output
        _helper(lambda data: _dict_update(data, {"a": -data["y"]}), fn_11, "y", "a")
        _helper(
            lambda data: _dict_update(data, {"a": (-data["y"], data["y"])}),
            fn_1n,
            "y",
            "a",
        )
        _helper(
            lambda data: _dict_update(data, {"a": data["x"] + data["z"]}),
            fn_n1,
            ["x", "z"],
            "a",
        )
        _helper(
            lambda data: _dict_update(
                data, {"a": (-data["y"], -data["z"], data["y"] + data["z"])}
            ),
            fn_nn,
            ["y", "z"],
            "a",
        )

    def test_collate_iterdatapipe(self):
        """Exercise ``collate`` with the default, a custom and a partial collate function."""
        arrs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
        source_dp = dp.iter.IterableWrapper(arrs)

        def _sum_collate(batch, default_type=torch.float):
            # Reduce the whole batch to a single scalar tensor of the given dtype.
            return torch.tensor(sum(batch), dtype=default_type)

        # Functional Test: without a custom collate_fn, the default collate function is used
        default_collated_dp = source_dp.collate()
        for raw, collated in zip(arrs, default_collated_dp):
            self.assertEqual(torch.tensor(raw), collated)

        # Functional Test: a custom collate function passed by keyword
        custom_collated_dp = source_dp.collate(collate_fn=_sum_collate)
        for raw, collated in zip(arrs, custom_collated_dp):
            self.assertEqual(torch.tensor(sum(raw), dtype=torch.float), collated)

        # Functional Test: a ``functools.partial`` collate function passed positionally
        collate_dp = source_dp.collate(partial(_sum_collate, default_type=torch.int))
        for raw, collated in zip(arrs, collate_dp):
            self.assertEqual(torch.tensor(sum(raw), dtype=torch.int), collated)

        # Reset Test: after a reset the DataPipe still yields the full, correct output
        n_before_reset = 1
        before_reset, after_reset = reset_after_n_next_calls(
            collate_dp, n_before_reset
        )
        # The first batch is [1, 2, 3], which sums to 6.
        self.assertEqual([torch.tensor(6, dtype=torch.int)], before_reset)
        for raw, collated in zip(arrs, after_reset):
            self.assertEqual(torch.tensor(sum(raw), dtype=torch.int), collated)

        # __len__ Test: the length is taken from the source DataPipe
        self.assertEqual(len(source_dp), len(collate_dp))

        # __len__ Test: a source without a valid __len__ gives a collated DataPipe without one
        no_len_dp = IDP_NoLen(arrs).collate()
        with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
            len(no_len_dp)
        for raw, collated in zip(arrs, no_len_dp):
            self.assertEqual(torch.tensor(raw), collated)

    def test_batch_iterdatapipe(self):
        """Functional, ``__len__`` and reset tests for ``IterDataPipe.batch``."""
        arrs = list(range(10))
        input_dp = dp.iter.IterableWrapper(arrs)

        # Functional Test: raise error when input argument `batch_size = 0`
        with self.assertRaises(AssertionError):
            input_dp.batch(batch_size=0)

        # Functional Test: by default, do not drop the last batch
        bs = 3
        batch_dp = input_dp.batch(batch_size=bs)
        self.assertEqual(len(batch_dp), 4)
        for i, batch in enumerate(batch_dp):
            # 10 elements in batches of 3 leave a trailing batch of size 1.
            self.assertEqual(len(batch), 1 if i == 3 else bs)
            self.assertEqual(batch, arrs[i * bs : i * bs + len(batch)])

        # Functional Test: Drop the last batch when specified
        bs = 4
        batch_dp = input_dp.batch(batch_size=bs, drop_last=True)
        for i, batch in enumerate(batch_dp):
            self.assertEqual(batch, arrs[i * bs : i * bs + len(batch)])

        # __len__ Test: with drop_last=True only full batches are counted (10 // 4 == 2),
        # iteration yields exactly that many batches, and each has exactly `bs` elements
        self.assertEqual(len(batch_dp), 2)
        self.assertEqual(len(list(batch_dp)), 2)
        for batch in batch_dp:
            self.assertEqual(len(batch), bs)

        # __len__ Test: the length is missing if the source DataPipe doesn't have length
        input_dp_nl = IDP_NoLen(range(10))
        batch_dp_nl = input_dp_nl.batch(batch_size=2)
        with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
            len(batch_dp_nl)

        # Reset Test: Ensures that the DataPipe can properly reset
        n_elements_before_reset = 1
        res_before_reset, res_after_reset = reset_after_n_next_calls(
            batch_dp, n_elements_before_reset
        )
        self.assertEqual([[0, 1, 2, 3]], res_before_reset)
        self.assertEqual([[0, 1, 2, 3], [4, 5, 6, 7]], res_after_reset)

    def test_unbatch_iterdatapipe(self):
        """Functional, error and reset tests for ``IterDataPipe.unbatch``."""
        target_length = 6
        prebatch_dp = dp.iter.IterableWrapper(range(target_length))

        # Functional Test: Unbatch DataPipe should be the same as pre-batch DataPipe
        input_dp = prebatch_dp.batch(3)
        unbatch_dp = input_dp.unbatch()
        # Comparing whole lists checks both the number of elements and their order.
        self.assertEqual(list(unbatch_dp), list(range(target_length)))

        # Functional Test: unbatch works for an input with a single level of nesting
        input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]])
        unbatch_dp = input_dp.unbatch()
        self.assertEqual(list(unbatch_dp), list(range(target_length)))

        input_dp = dp.iter.IterableWrapper([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])

        # Functional Test: by default only the outermost level of a multi-level input is removed
        unbatch_dp = input_dp.unbatch()
        self.assertEqual(list(unbatch_dp), [[0, 1], [2, 3], [4, 5], [6, 7]])

        # Functional Test: unbatching multiple levels at the same time
        flat_expected = [0, 1, 2, 3, 4, 5, 6, 7]
        unbatch_dp = input_dp.unbatch(unbatch_level=2)
        self.assertEqual(list(unbatch_dp), flat_expected)

        # Functional Test: unbatching all levels at the same time
        unbatch_dp = input_dp.unbatch(unbatch_level=-1)
        self.assertEqual(list(unbatch_dp), flat_expected)

        # Functional Test: raises error when input unbatch_level is less than -1
        input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]])
        with self.assertRaises(ValueError):
            # Iterate inside the context in case the level is only checked lazily.
            list(input_dp.unbatch(unbatch_level=-2))

        # Functional Test: raises error when input unbatch_level is too high
        with self.assertRaises(IndexError):
            list(input_dp.unbatch(unbatch_level=5))

        # Reset Test: unbatch_dp resets properly
        input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]])
        unbatch_dp = input_dp.unbatch(unbatch_level=-1)
        n_elements_before_reset = 3
        res_before_reset, res_after_reset = reset_after_n_next_calls(
            unbatch_dp, n_elements_before_reset
        )
        self.assertEqual([0, 1, 2], res_before_reset)
        self.assertEqual([0, 1, 2, 3, 4, 5], res_after_reset)

    def test_filter_datapipe(self):
        """Functional, ``input_col``, ``__len__`` and reset tests for ``IterDataPipe.filter``."""
        input_ds = dp.iter.IterableWrapper(range(10))

        def _filter_fn(data, val):
            return data >= val

        # Functional Test: filter works with partial function
        filter_dp = input_ds.filter(partial(_filter_fn, val=5))
        self.assertEqual(list(filter_dp), list(range(5, 10)))

        def _non_bool_fn(data):
            return 1

        # Functional Test: filter function must return bool
        filter_dp = input_ds.filter(filter_fn=_non_bool_fn)
        with self.assertRaises(ValueError):
            # Construction succeeded above; the return type is only checked while iterating.
            list(filter_dp)

        # Functional Test: Specify input_col
        tuple_input_ds = dp.iter.IterableWrapper([(d - 1, d, d + 1) for d in range(10)])

        # Single input_col
        input_col_1_dp = tuple_input_ds.filter(partial(_filter_fn, val=5), input_col=1)
        self.assertEqual(
            list(input_col_1_dp), [(d - 1, d, d + 1) for d in range(5, 10)]
        )

        # Multiple input_col
        def _mul_filter_fn(a, b):
            return a + b < 10

        input_col_2_dp = tuple_input_ds.filter(_mul_filter_fn, input_col=[0, 2])
        self.assertEqual(list(input_col_2_dp), [(d - 1, d, d + 1) for d in range(5)])

        # Invalid input_col: a single column for a function that needs two arguments
        with self.assertRaises(ValueError):
            tuple_input_ds.filter(_mul_filter_fn, input_col=0)

        # A single column is fine once the second argument is bound by `partial`
        p_mul_filter_fn = partial(_mul_filter_fn, b=1)
        out = tuple_input_ds.filter(p_mul_filter_fn, input_col=0)
        self.assertEqual(list(out), [(d - 1, d, d + 1) for d in range(10)])

        # ... or when the second argument has a default value
        def _mul_filter_fn_with_defaults(a, b=1):
            return a + b < 10

        out = tuple_input_ds.filter(_mul_filter_fn_with_defaults, input_col=0)
        self.assertEqual(list(out), [(d - 1, d, d + 1) for d in range(10)])

        # Keyword-only parameters cannot receive positional column values
        def _mul_filter_fn_with_kw_only(*, a, b):
            return a + b < 10

        with self.assertRaises(ValueError):
            tuple_input_ds.filter(_mul_filter_fn_with_kw_only, input_col=0)

        def _mul_filter_fn_with_kw_only_1_default(*, a, b=1):
            return a + b < 10

        with self.assertRaises(ValueError):
            tuple_input_ds.filter(_mul_filter_fn_with_kw_only_1_default, input_col=0)

        # __len__ Test: DataPipe has no valid len
        with self.assertRaisesRegex(TypeError, r"has no len"):
            len(filter_dp)

        # Reset Test: DataPipe resets correctly
        filter_dp = input_ds.filter(partial(_filter_fn, val=5))
        n_elements_before_reset = 3
        res_before_reset, res_after_reset = reset_after_n_next_calls(
            filter_dp, n_elements_before_reset
        )
        self.assertEqual(list(range(5, 10))[:n_elements_before_reset], res_before_reset)
        self.assertEqual(list(range(5, 10)), res_after_reset)

    def test_sampler_iterdatapipe(self):
        """Tests for ``dp.iter.Sampler`` with the default and a random sampler."""
        input_dp = dp.iter.IterableWrapper(range(10))
        # Default SequentialSampler: preserves order, element count and length
        sampled_dp = dp.iter.Sampler(input_dp)  # type: ignore[var-annotated]
        self.assertEqual(len(sampled_dp), 10)
        self.assertEqual(list(sampled_dp), list(range(10)))

        # RandomSampler (with replacement)
        random_sampled_dp = dp.iter.Sampler(
            input_dp, sampler=RandomSampler, sampler_kwargs={"replacement": True}
        )  # type: ignore[var-annotated] # noqa: B950
        # Actually exercise the random-sampler DataPipe: it must report the source
        # length, yield that many values, and every value must come from range(10).
        self.assertEqual(len(random_sampled_dp), 10)
        random_res = list(random_sampled_dp)
        self.assertEqual(len(random_res), 10)
        for x in random_res:
            self.assertIn(x, range(10))

        # Requires `__len__` to build SamplerDataPipe
        input_dp_nolen = IDP_NoLen(range(10))
        with self.assertRaises(AssertionError):
            dp.iter.Sampler(input_dp_nolen)

    def test_stream_reader_iterdatapipe(self):
        from io import StringIO

        input_dp = dp.iter.IterableWrapper(
            [("f1", StringIO("abcde")), ("f2", StringIO("bcdef"))]
        )
        expected_res = ["abcde", "bcdef"]

        # Functional Test: Read full chunk
        dp1 = input_dp.read_from_stream()
        self.assertEqual([d[1] for d in dp1], expected_res)

        # Functional Test: Read full chunk
        dp2 = input_dp.read_from_stream(chunk=1)
        self.assertEqual([d[1] for d in dp2], [c for s in expected_res for c in s])

        # `__len__` Test
        with self.assertRaises(TypeError):
            len(dp1)

    def test_shuffler_iterdatapipe(self):
        """Functional, seeding, reset, ``__len__`` and serialization tests for ``shuffle``."""
        input_dp = dp.iter.IterableWrapper(list(range(10)))

        # Functional Test: buffer_size=0 is rejected
        with self.assertRaises(AssertionError):
            input_dp.shuffle(buffer_size=0)

        # Functional Test: No seed
        shuffler_dp = input_dp.shuffle()
        self.assertEqual(set(range(10)), set(shuffler_dp))

        # Functional Test: With global seed
        torch.manual_seed(123)
        shuffler_dp = input_dp.shuffle()
        res = list(shuffler_dp)
        torch.manual_seed(123)
        self.assertEqual(list(shuffler_dp), res)

        # Functional Test: Set seed
        shuffler_dp = input_dp.shuffle().set_seed(123)
        res = list(shuffler_dp)
        shuffler_dp.set_seed(123)
        self.assertEqual(list(shuffler_dp), res)

        # Functional Test: deactivate shuffling via set_shuffle
        unshuffled_dp = input_dp.shuffle().set_shuffle(False)
        self.assertEqual(list(unshuffled_dp), list(input_dp))

        # Reset Test:
        # Build the membership set once instead of once per loop iteration.
        expected_elements = set(range(10))
        shuffler_dp = input_dp.shuffle()
        n_elements_before_reset = 5
        res_before_reset, res_after_reset = reset_after_n_next_calls(
            shuffler_dp, n_elements_before_reset
        )
        self.assertEqual(5, len(res_before_reset))
        for x in res_before_reset:
            self.assertIn(x, expected_elements)
        self.assertEqual(expected_elements, set(res_after_reset))

        # __len__ Test: returns the length of the input DataPipe
        shuffler_dp = input_dp.shuffle()
        self.assertEqual(10, len(shuffler_dp))

        # Serialization Test
        from torch.utils.data.datapipes._hook_iterator import _SnapshotState

        def _serialization_helper(bs):
            # Pickle a partially-consumed shuffler, restore its source to the same
            # position, and check that the copy yields the same remaining elements.
            shuffler_dp = input_dp.shuffle(buffer_size=bs)
            it = iter(shuffler_dp)
            for _ in range(2):
                next(it)
            shuffler_dp_copy = pickle.loads(pickle.dumps(shuffler_dp))
            _simple_graph_snapshot_restoration(
                shuffler_dp_copy.datapipe,
                shuffler_dp.datapipe._number_of_samples_yielded,
            )

            exp = list(it)
            shuffler_dp_copy._snapshot_state = _SnapshotState.Restored
            self.assertEqual(exp, list(shuffler_dp_copy))

        buffer_sizes = [2, 5, 15]
        for bs in buffer_sizes:
            _serialization_helper(bs)

    def test_zip_iterdatapipe(self):
        """Type checking, ``__len__``, functional and reset tests for ``dp.iter.Zipper``."""
        # Functional Test: every input must be an `IterDataPipe`, otherwise TypeError
        with self.assertRaises(TypeError):
            dp.iter.Zipper(dp.iter.IterableWrapper(range(10)), list(range(10)))  # type: ignore[arg-type]

        # __len__ Test: if any input lacks a valid length, len() of the zip raises TypeError
        no_len_zip_dp = dp.iter.Zipper(
            dp.iter.IterableWrapper(range(10)), IDP_NoLen(range(5))
        )  # type: ignore[var-annotated]
        with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
            len(no_len_zip_dp)

        # Functional Test: iteration still pairs up the elements
        pairs = [(i, i) for i in range(5)]
        self.assertEqual(list(no_len_zip_dp), pairs)

        # Functional Test: inputs of different lengths are zipped up to the shortest one
        zipped_dp = dp.iter.Zipper(
            dp.iter.IterableWrapper(range(10)), dp.iter.IterableWrapper(range(5))
        )

        # __len__ Test: the length equals that of the shortest input
        self.assertEqual(len(zipped_dp), 5)

        # Reset Test: a reset mid-iteration restarts from the beginning
        n_before_reset = 3
        before_reset, after_reset = reset_after_n_next_calls(zipped_dp, n_before_reset)
        self.assertEqual(pairs[:n_before_reset], before_reset)
        self.assertEqual(pairs, after_reset)


class TestFunctionalMapDataPipe(TestCase):
    """Tests for the built-in ``MapDataPipe`` implementations and their functional forms."""

    def _serialization_test_helper(self, datapipe, use_dill):
        """Round-trip ``datapipe`` through ``dill`` or ``pickle`` and compare contents.

        Args:
            datapipe: the DataPipe to serialize.
            use_dill: serialize with ``dill`` when True, otherwise with ``pickle``.
        """
        if use_dill:
            serialized_dp = dill.dumps(datapipe)
            deserialized_dp = dill.loads(serialized_dp)
        else:
            serialized_dp = pickle.dumps(datapipe)
            deserialized_dp = pickle.loads(serialized_dp)
        try:
            self.assertEqual(list(datapipe), list(deserialized_dp))
        except AssertionError:
            print(f"{datapipe} is failing.")
            # A bare ``raise`` re-raises the active exception unchanged.
            raise

    def _serialization_test_for_single_dp(self, dp, use_dill=False):
        """Check serialization of ``dp`` before, during and after iteration.

        Note: the parameter ``dp`` shadows the ``datapipes`` module within this method.
        """
        # 1. Testing for serialization before any iteration starts
        self._serialization_test_helper(dp, use_dill)
        # 2. Testing for serialization after DataPipe is partially read
        it = iter(dp)
        _ = next(it)
        self._serialization_test_helper(dp, use_dill)
        # 3. Testing for serialization after DataPipe is fully read
        _ = list(dp)
        self._serialization_test_helper(dp, use_dill)

    def test_serializable(self):
        """Every listed MapDataPipe can be pickled at any stage of iteration."""
        # Entries are (DataPipe class, custom input or None, positional args, keyword args).
        picklable_datapipes: List = [
            (dp.map.Batcher, None, (2,), {}),
            (dp.map.Concater, None, (dp.map.SequenceWrapper(range(10)),), {}),
            (dp.map.Mapper, None, (), {}),
            (dp.map.Mapper, None, (_fake_fn,), {}),
            (dp.map.Mapper, None, (partial(_fake_add, 1),), {}),
            (dp.map.SequenceWrapper, range(10), (), {}),
            (dp.map.Shuffler, dp.map.SequenceWrapper([0] * 5), (), {}),
            (dp.map.Zipper, None, (dp.map.SequenceWrapper(range(10)),), {}),
        ]
        for dpipe, custom_input, dp_args, dp_kwargs in picklable_datapipes:
            if custom_input is None:
                custom_input = dp.map.SequenceWrapper(range(10))
            datapipe = dpipe(custom_input, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
            self._serialization_test_for_single_dp(datapipe)

    def test_serializable_with_dill(self):
        """Only for DataPipes that take in a function as argument"""
        input_dp = dp.map.SequenceWrapper(range(10))

        datapipes_with_lambda_fn: List[
            Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]]
        ] = [
            (dp.map.Mapper, (lambda_fn1,), {}),
        ]

        def _local_fns():
            def _fn1(x):
                return x

            return _fn1

        fn1 = _local_fns()

        datapipes_with_local_fn: List[
            Tuple[Type[MapDataPipe], Tuple, Dict[str, Any]]
        ] = [
            (dp.map.Mapper, (fn1,), {}),
        ]

        if HAS_DILL:
            for dpipe, dp_args, dp_kwargs in (
                datapipes_with_lambda_fn + datapipes_with_local_fn
            ):
                _ = dill.dumps(dpipe(input_dp, *dp_args, **dp_kwargs))  # type: ignore[call-arg]
        else:
            # Without dill, construction warns and plain pickling must fail.
            msgs = (
                r"^Lambda function is not supported by pickle",
                r"^Local function is not supported by pickle",
            )
            for dps, msg in zip(
                (datapipes_with_lambda_fn, datapipes_with_local_fn), msgs
            ):
                for dpipe, dp_args, dp_kwargs in dps:
                    with self.assertWarnsRegex(UserWarning, msg):
                        datapipe = dpipe(input_dp, *dp_args, **dp_kwargs)  # type: ignore[call-arg]
                    with self.assertRaises((pickle.PicklingError, AttributeError)):
                        pickle.dumps(datapipe)

    def test_docstring(self):
        """
        Ensure functional form of MapDataPipe has the correct docstring from
        the class form.

        Regression test for https://github.com/pytorch/data/issues/792.
        """
        input_dp = dp.map.SequenceWrapper(range(10))

        for dp_funcname in [
            "batch",
            "concat",
            "map",
            "shuffle",
            "zip",
        ]:
            functional_form = getattr(input_dp, dp_funcname)
            if sys.version_info >= (3, 9):
                docstring = pydoc.render_doc(thing=functional_form, forceload=True)
            else:
                # pydoc works differently on Python 3.8, see
                # https://docs.python.org/3/whatsnew/3.9.html#pydoc
                docstring = functional_form.__doc__
            # unittest assertions (unlike bare ``assert``) still run under ``python -O``.
            self.assertIn(f"(functional name: ``{dp_funcname}``)", docstring)
            self.assertIn("Args:", docstring)
            self.assertTrue("Example:" in docstring or "Examples:" in docstring)

    def test_sequence_wrapper_datapipe(self):
        """Functional, deepcopy, reset and ``__len__`` tests for ``SequenceWrapper``."""
        seq = list(range(10))
        input_dp = dp.map.SequenceWrapper(seq)

        # Functional Test: all elements are equal in the same order
        self.assertEqual(seq, list(input_dp))

        # Functional Test: confirm deepcopy works by default
        seq.append(11)
        self.assertEqual(list(range(10)), list(input_dp))  # input_dp shouldn't have 11

        # Functional Test: non-deepcopy version is working
        seq2 = [1, 2, 3]
        input_dp_non_deep = dp.map.SequenceWrapper(seq2, deepcopy=False)
        seq2.append(4)
        self.assertEqual(list(seq2), list(input_dp_non_deep))  # should have 4

        # Reset Test: reset the DataPipe (rebind `seq` to drop the appended 11)
        seq = list(range(10))
        n_elements_before_reset = 5
        res_before_reset, res_after_reset = reset_after_n_next_calls(
            input_dp, n_elements_before_reset
        )
        self.assertEqual(list(range(5)), res_before_reset)
        self.assertEqual(seq, res_after_reset)

        # __len__ Test: inherits length from sequence
        self.assertEqual(len(seq), len(input_dp))

    def test_concat_mapdatapipe(self):
        """Validation, indexing and iteration tests for ``concat``."""
        input_dp1 = dp.map.SequenceWrapper(range(10))
        input_dp2 = dp.map.SequenceWrapper(range(5))

        with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"):
            dp.map.Concater()

        with self.assertRaisesRegex(
            TypeError, r"Expected all inputs to be `MapDataPipe`"
        ):
            dp.map.Concater(input_dp1, ())  # type: ignore[arg-type]

        concat_dp = input_dp1.concat(input_dp2)
        expected = list(range(10)) + list(range(5))
        self.assertEqual(len(concat_dp), 15)
        for index in range(15):
            self.assertEqual(concat_dp[index], expected[index])
        self.assertEqual(list(concat_dp), expected)

    def test_zip_mapdatapipe(self):
        """Validation, indexing, batching and ``__len__`` tests for ``zip``."""
        input_dp1 = dp.map.SequenceWrapper(range(10))
        input_dp2 = dp.map.SequenceWrapper(range(5))
        input_dp3 = dp.map.SequenceWrapper(range(15))

        # Functional Test: requires at least one input DataPipe
        with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"):
            dp.map.Zipper()

        # Functional Test: all inputs must be MapDataPipes
        with self.assertRaisesRegex(
            TypeError, r"Expected all inputs to be `MapDataPipe`"
        ):
            dp.map.Zipper(input_dp1, ())  # type: ignore[arg-type]

        # Functional Test: Zip the elements up as a tuples
        zip_dp = input_dp1.zip(input_dp2, input_dp3)
        self.assertEqual([(i, i, i) for i in range(5)], [zip_dp[i] for i in range(5)])

        # Functional Test: Raise IndexError when index equal or exceed the length of the shortest DataPipe
        with self.assertRaisesRegex(IndexError, r"out of range"):
            input_dp1.zip(input_dp2, input_dp3)[5]

        # Functional Test: Ensure `zip` can combine `Batcher` with others
        dp1 = dp.map.SequenceWrapper(range(10))
        batch_dp1 = dp1.batch(2)  # 5 batches
        dp2 = dp.map.SequenceWrapper(range(10))
        batch_dp2 = dp2.batch(3)  # 4 batches (the last one is partial)
        zip_dp1 = batch_dp1.zip(batch_dp2)
        self.assertEqual(4, len(list(zip_dp1)))
        zip_dp2 = batch_dp1.zip(dp2)
        self.assertEqual(5, len(list(zip_dp2)))

        # __len__ Test: returns the length of the shortest DataPipe
        zip_dp = input_dp1.zip(input_dp2, input_dp3)
        self.assertEqual(5, len(zip_dp))

    def test_shuffler_mapdatapipe(self):
        """Functional, seeding, reset, ``__len__`` and serialization tests for ``shuffle``."""
        input_dp1 = dp.map.SequenceWrapper(range(10))
        input_dp2 = dp.map.SequenceWrapper({"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})

        # Functional Test: Assumes 0-index when indices is not given
        shuffler_dp = input_dp1.shuffle()
        self.assertEqual(set(range(10)), set(shuffler_dp))

        # Functional Test: Custom indices are working
        shuffler_dp = input_dp2.shuffle(indices=["a", "b", "c", "d", "e"])
        self.assertEqual(set(range(1, 6)), set(shuffler_dp))

        # Functional Test: With global seed
        torch.manual_seed(123)
        shuffler_dp = input_dp1.shuffle()
        res = list(shuffler_dp)
        torch.manual_seed(123)
        self.assertEqual(list(shuffler_dp), res)

        # Functional Test: Set seed
        shuffler_dp = input_dp1.shuffle().set_seed(123)
        res = list(shuffler_dp)
        shuffler_dp.set_seed(123)
        self.assertEqual(list(shuffler_dp), res)

        # Functional Test: deactivate shuffling via set_shuffle
        unshuffled_dp = input_dp1.shuffle().set_shuffle(False)
        self.assertEqual(list(unshuffled_dp), list(input_dp1))

        # Reset Test:
        # Build the membership set once instead of once per loop iteration.
        expected_elements = set(range(10))
        shuffler_dp = input_dp1.shuffle()
        n_elements_before_reset = 5
        res_before_reset, res_after_reset = reset_after_n_next_calls(
            shuffler_dp, n_elements_before_reset
        )
        self.assertEqual(5, len(res_before_reset))
        for x in res_before_reset:
            self.assertIn(x, expected_elements)
        self.assertEqual(expected_elements, set(res_after_reset))

        # __len__ Test: returns the length of the input DataPipe
        shuffler_dp = input_dp1.shuffle()
        self.assertEqual(10, len(shuffler_dp))

        # Serialization Test
        from torch.utils.data.datapipes._hook_iterator import _SnapshotState

        shuffler_dp = input_dp1.shuffle()
        it = iter(shuffler_dp)
        for _ in range(2):
            next(it)
        shuffler_dp_copy = pickle.loads(pickle.dumps(shuffler_dp))

        exp = list(it)
        shuffler_dp_copy._snapshot_state = _SnapshotState.Restored
        self.assertEqual(exp, list(shuffler_dp_copy))

    def test_map_mapdatapipe(self):
        """``map`` applies plain and ``functools.partial`` functions element-wise."""
        arr = range(10)
        input_dp = dp.map.SequenceWrapper(arr)

        def fn(item, dtype=torch.float, *, sum=False):
            data = torch.tensor(item, dtype=dtype)
            return data if not sum else data.sum()

        map_dp = input_dp.map(fn)
        self.assertEqual(len(input_dp), len(map_dp))
        for index in arr:
            self.assertEqual(
                map_dp[index], torch.tensor(input_dp[index], dtype=torch.float)
            )

        map_dp = input_dp.map(partial(fn, dtype=torch.int, sum=True))
        self.assertEqual(len(input_dp), len(map_dp))
        for index in arr:
            self.assertEqual(
                map_dp[index], torch.tensor(input_dp[index], dtype=torch.int).sum()
            )

    def test_batch_mapdatapipe(self):
        """Functional, nesting, reset and ``__len__`` tests for ``dp.map.Batcher``."""
        arr = list(range(13))
        input_dp = dp.map.SequenceWrapper(arr)

        # Functional Test: batches top level by default
        batch_dp = dp.map.Batcher(input_dp, batch_size=2)
        self.assertEqual(
            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12]], list(batch_dp)
        )

        # Functional Test: drop_last on command
        batch_dp = dp.map.Batcher(input_dp, batch_size=2, drop_last=True)
        self.assertEqual(
            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]], list(batch_dp)
        )

        # Functional Test: nested batching
        batch_dp_2 = batch_dp.batch(batch_size=3)
        self.assertEqual(
            [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]], list(batch_dp_2)
        )

        # Reset Test:
        n_elements_before_reset = 3
        res_before_reset, res_after_reset = reset_after_n_next_calls(
            batch_dp, n_elements_before_reset
        )
        self.assertEqual([[0, 1], [2, 3], [4, 5]], res_before_reset)
        self.assertEqual(
            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]], res_after_reset
        )

        # __len__ Test:
        self.assertEqual(6, len(batch_dp))
        self.assertEqual(2, len(batch_dp_2))


# Combining NamedTuple with Generic is only exercised on Python 3.7 and 3.8:
# Python 3.6 hits a metaclass conflict, and on Python 3.9 multiple inheritance
# with NamedTuple is not supported.
_generic_namedtuple_allowed = (3, 7) <= sys.version_info < (3, 9)
if _generic_namedtuple_allowed:

    class InvalidData(NamedTuple, Generic[T_co]):
        name: str
        data: T_co


class TestTyping(TestCase):
    r"""Tests for DataPipe type annotations and their compile-, construct- and run-time checks.

    Most tests here are disabled through ``skipTyping`` ("TODO: Fix typing bug").
    """

    def test_isinstance(self):
        class A(IterDataPipe):
            pass

        class B(IterDataPipe):
            pass

        a = A()
        self.assertIsInstance(a, A)
        self.assertNotIsInstance(a, B)

    def test_protocol(self):
        # Older Pythons only expose the private ``_Protocol`` name.
        try:
            from typing import Protocol  # type: ignore[attr-defined]
        except ImportError:
            from typing import _Protocol  # type: ignore[attr-defined]

            Protocol = _Protocol

        class P(Protocol):
            pass

        # Parameterizing an IterDataPipe with a Protocol must not raise.
        class A(IterDataPipe[P]):
            pass

    @skipTyping
    def test_subtype(self):
        from torch.utils.data.datapipes._typing import issubtype

        basic_type = (int, str, bool, float, complex, list, tuple, dict, set, T_co)
        for t in basic_type:
            self.assertTrue(issubtype(t, t))
            self.assertTrue(issubtype(t, Any))
            if t == T_co:
                self.assertTrue(issubtype(Any, t))
            else:
                self.assertFalse(issubtype(Any, t))
        for t1, t2 in itertools.product(basic_type, basic_type):
            if t1 == t2 or t2 == T_co:
                self.assertTrue(issubtype(t1, t2))
            else:
                self.assertFalse(issubtype(t1, t2))

        T = TypeVar("T", int, str)
        S = TypeVar("S", bool, Union[str, int], Tuple[int, T])  # type: ignore[valid-type]
        # (sub, par) pairs where ``sub`` is a strict subtype of ``par``.
        types = (
            (int, Optional[int]),
            (List, Union[int, list]),
            (Tuple[int, str], S),
            (Tuple[int, str], tuple),
            (T, S),
            (S, T_co),
            (T, Union[S, Set]),
        )
        for sub, par in types:
            self.assertTrue(issubtype(sub, par))
            self.assertFalse(issubtype(par, sub))

        subscriptable_types = {
            List: 1,
            Tuple: 2,  # use 2 parameters
            Set: 1,
            Dict: 2,
        }
        for subscript_type, n in subscriptable_types.items():
            for ts in itertools.combinations(types, n):
                subs, pars = zip(*ts)
                sub = subscript_type[subs]  # type: ignore[index]
                par = subscript_type[pars]  # type: ignore[index]
                self.assertTrue(issubtype(sub, par))
                self.assertFalse(issubtype(par, sub))
                # Non-recursive check
                self.assertTrue(issubtype(par, sub, recursive=False))

    @skipTyping
    def test_issubinstance(self):
        from torch.utils.data.datapipes._typing import issubinstance

        basic_data = (1, "1", True, 1.0, complex(1.0, 0.0))
        basic_type = (int, str, bool, float, complex)
        S = TypeVar("S", bool, Union[str, int])
        for d in basic_data:
            self.assertTrue(issubinstance(d, Any))
            self.assertTrue(issubinstance(d, T_co))
            if type(d) in (bool, int, str):
                self.assertTrue(issubinstance(d, S))
            else:
                self.assertFalse(issubinstance(d, S))
            for t in basic_type:
                # Exact type comparison on purpose: the expectation is that
                # e.g. ``True`` is not reported as an instance of ``int``.
                if type(d) == t:
                    self.assertTrue(issubinstance(d, t))
                else:
                    self.assertFalse(issubinstance(d, t))
        # list/set
        dt = (([1, "1", 2], List), (set({1, "1", 2}), Set))
        for d, t in dt:
            self.assertTrue(issubinstance(d, t))
            self.assertTrue(issubinstance(d, t[T_co]))  # type: ignore[index]
            self.assertFalse(issubinstance(d, t[int]))  # type: ignore[index]

        # dict
        d = {"1": 1, "2": 2.0}
        self.assertTrue(issubinstance(d, Dict))
        self.assertTrue(issubinstance(d, Dict[str, T_co]))
        self.assertFalse(issubinstance(d, Dict[str, int]))

        # tuple
        d = (1, "1", 2)
        self.assertTrue(issubinstance(d, Tuple))
        self.assertTrue(issubinstance(d, Tuple[int, str, T_co]))
        self.assertFalse(issubinstance(d, Tuple[int, Any]))
        self.assertFalse(issubinstance(d, Tuple[int, int, int]))

    # Static checking annotation
    @skipTyping
    def test_compile_time(self):
        with self.assertRaisesRegex(TypeError, r"Expected 'Iterator' as the return"):

            class InvalidDP1(IterDataPipe[int]):
                def __iter__(self) -> str:  # type: ignore[misc, override]
                    yield 0

        with self.assertRaisesRegex(TypeError, r"Expected return type of '__iter__'"):

            class InvalidDP2(IterDataPipe[Tuple]):
                def __iter__(self) -> Iterator[int]:  # type: ignore[override]
                    yield 0

        with self.assertRaisesRegex(TypeError, r"Expected return type of '__iter__'"):

            class InvalidDP3(IterDataPipe[Tuple[int, str]]):
                def __iter__(self) -> Iterator[tuple]:  # type: ignore[override]
                    yield (0,)

        if _generic_namedtuple_allowed:
            with self.assertRaisesRegex(
                TypeError, r"is not supported by Python typing"
            ):

                class InvalidDP4(IterDataPipe["InvalidData[int]"]):  # type: ignore[type-arg, misc]
                    pass

        class DP1(IterDataPipe[Tuple[int, str]]):
            def __init__(self, length):
                self.length = length

            def __iter__(self) -> Iterator[Tuple[int, str]]:
                for d in range(self.length):
                    yield d, str(d)

        self.assertTrue(issubclass(DP1, IterDataPipe))
        dp1 = DP1(10)
        self.assertTrue(DP1.type.issubtype(dp1.type) and dp1.type.issubtype(DP1.type))  # type: ignore[attr-defined]
        dp1_ = DP1(5)
        self.assertEqual(dp1.type, dp1_.type)

        with self.assertRaisesRegex(TypeError, r"is not a generic class"):

            class InvalidDP5(DP1[tuple]):  # type: ignore[type-arg]
                def __iter__(self) -> Iterator[tuple]:  # type: ignore[override]
                    yield (0,)

        class DP2(IterDataPipe[T_co]):
            def __iter__(self) -> Iterator[T_co]:
                yield from range(10)  # type: ignore[misc]

        self.assertTrue(issubclass(DP2, IterDataPipe))
        dp2 = DP2()  # type: ignore[var-annotated]
        self.assertTrue(DP2.type.issubtype(dp2.type) and dp2.type.issubtype(DP2.type))  # type: ignore[attr-defined]
        dp2_ = DP2()  # type: ignore[var-annotated]
        self.assertEqual(dp2.type, dp2_.type)

        class DP3(IterDataPipe[Tuple[T_co, str]]):
            r"""DataPipe without fixed type with __init__ function"""

            def __init__(self, datasource):
                self.datasource = datasource

            def __iter__(self) -> Iterator[Tuple[T_co, str]]:
                for d in self.datasource:
                    yield d, str(d)

        self.assertTrue(issubclass(DP3, IterDataPipe))
        dp3 = DP3(range(10))  # type: ignore[var-annotated]
        self.assertTrue(DP3.type.issubtype(dp3.type) and dp3.type.issubtype(DP3.type))  # type: ignore[attr-defined]
        dp3_ = DP3(5)  # type: ignore[var-annotated]
        self.assertEqual(dp3.type, dp3_.type)

        class DP4(IterDataPipe[tuple]):
            r"""DataPipe without __iter__ annotation"""

            def __iter__(self):
                raise NotImplementedError

        self.assertTrue(issubclass(DP4, IterDataPipe))
        dp4 = DP4()
        self.assertEqual(dp4.type.param, tuple)

        class DP5(IterDataPipe):
            r"""DataPipe without type annotation"""

            def __iter__(self) -> Iterator[str]:
                raise NotImplementedError

        self.assertTrue(issubclass(DP5, IterDataPipe))
        dp5 = DP5()
        from torch.utils.data.datapipes._typing import issubtype

        self.assertTrue(
            issubtype(dp5.type.param, Any) and issubtype(Any, dp5.type.param)
        )

        class DP6(IterDataPipe[int]):
            r"""DataPipe with plain Iterator"""

            def __iter__(self) -> Iterator:
                raise NotImplementedError

        self.assertTrue(issubclass(DP6, IterDataPipe))
        dp6 = DP6()
        self.assertEqual(dp6.type.param, int)

        class DP7(IterDataPipe[Awaitable[T_co]]):
            r"""DataPipe with abstract base class"""

        self.assertTrue(issubclass(DP7, IterDataPipe))
        self.assertEqual(DP7.type.param, Awaitable[T_co])  # type: ignore[attr-defined]

        class DP8(DP7[str]):
            r"""DataPipe subclass from a DataPipe with abc type"""

        self.assertTrue(issubclass(DP8, IterDataPipe))
        self.assertEqual(DP8.type.param, Awaitable[str])  # type: ignore[attr-defined]

    @skipTyping
    def test_construct_time(self):
        class DP0(IterDataPipe[Tuple]):
            @argument_validation
            def __init__(self, dp: IterDataPipe):
                self.dp = dp

            def __iter__(self) -> Iterator[Tuple]:
                for d in self.dp:
                    yield d, str(d)

        class DP1(IterDataPipe[int]):
            @argument_validation
            def __init__(self, dp: IterDataPipe[Tuple[int, str]]):
                self.dp = dp

            def __iter__(self) -> Iterator[int]:
                for a, _ in self.dp:
                    yield a

        # Non-DataPipe input with DataPipe hint
        datasource = [(1, "1"), (2, "2"), (3, "3")]
        with self.assertRaisesRegex(
            TypeError, r"Expected argument 'dp' as a IterDataPipe"
        ):
            DP0(datasource)

        # DP0 is typed as a plain ``Tuple``, which is not a subtype of DP1's hint.
        dp0 = DP0(dp.iter.IterableWrapper(range(10)))
        with self.assertRaisesRegex(
            TypeError, r"Expected type of argument 'dp' as a subtype"
        ):
            DP1(dp0)

    @skipTyping
    def test_runtime(self):
        class DP(IterDataPipe[Tuple[int, T_co]]):
            def __init__(self, datasource):
                self.ds = datasource

            @runtime_validation
            def __iter__(self) -> Iterator[Tuple[int, T_co]]:
                yield from self.ds

        dss = ([(1, "1"), (2, "2")], [(1, 1), (2, "2")])
        for ds in dss:
            dp0 = DP(ds)  # type: ignore[var-annotated]
            self.assertEqual(list(dp0), ds)
            # Reset __iter__
            self.assertEqual(list(dp0), ds)

        dss = (
            [(1, 1), ("2", 2)],  # type: ignore[assignment, list-item]
            [[1, "1"], [2, "2"]],  # type: ignore[list-item]
            [1, "1", 2, "2"],
        )
        for ds in dss:
            dp0 = DP(ds)
            with self.assertRaisesRegex(
                RuntimeError, r"Expected an instance as subtype"
            ):
                list(dp0)

            # Validation is skipped inside the context manager, including when nested.
            with runtime_validation_disabled():
                self.assertEqual(list(dp0), ds)
                with runtime_validation_disabled():
                    self.assertEqual(list(dp0), ds)

            # ...and is active again once the context manager exits.
            with self.assertRaisesRegex(
                RuntimeError, r"Expected an instance as subtype"
            ):
                list(dp0)

    @skipTyping
    def test_reinforce(self):
        T = TypeVar("T", int, str)

        class DP(IterDataPipe[T]):
            def __init__(self, ds):
                self.ds = ds

            @runtime_validation
            def __iter__(self) -> Iterator[T]:
                yield from self.ds

        ds = list(range(10))
        # Valid type reinforcement. (Previously ``assertTrue(dp0.type, int)``,
        # which passed ``int`` as the failure message and so could never fail.)
        dp0 = DP(ds).reinforce_type(int)
        self.assertEqual(dp0.type.param, int)
        self.assertEqual(list(dp0), ds)

        # Invalid type
        with self.assertRaisesRegex(TypeError, r"'expected_type' must be a type"):
            DP(ds).reinforce_type(1)

        # Type is not subtype
        with self.assertRaisesRegex(
            TypeError, r"Expected 'expected_type' as subtype of"
        ):
            DP(ds).reinforce_type(float)

        # Invalid data at runtime
        dp3 = DP(ds).reinforce_type(str)
        with self.assertRaisesRegex(RuntimeError, r"Expected an instance as subtype"):
            list(dp3)

        # Context Manager to disable the runtime validation
        with runtime_validation_disabled():
            self.assertEqual(list(dp3), ds)


class NumbersDataset(IterDataPipe):
    r"""Sized IterDataPipe that yields the integers ``0, 1, ..., size - 1`` in order."""

    def __init__(self, size=10):
        self.size = size

    def __len__(self):
        return self.size

    def __iter__(self):
        for number in range(self.size):
            yield number


class TestGraph(TestCase):
    r"""Tests for ``traverse_dps`` and ``get_all_graph_pipes`` on DataPipe graphs."""

    class CustomIterDataPipe(IterDataPipe):
        r"""Adds ``v`` (default 1) to every element of ``source_dp``.

        The inner ``map`` holds the bound method ``self.add_v``, so the graph
        refers back to this DataPipe. ``__hash__`` raises, making instances
        deliberately unhashable.
        """

        def add_v(self, x):
            return x + self.v

        def __init__(self, source_dp, v=1):
            self._dp = source_dp.map(self.add_v)
            # Previously hard-coded to 1, silently ignoring the ``v`` argument.
            self.v = v

        def __iter__(self):
            yield from self._dp

        def __hash__(self):
            raise NotImplementedError

    def test_simple_traverse(self):
        numbers_dp = NumbersDataset(size=50)
        shuffled_dp = numbers_dp.shuffle()
        sharded_dp = shuffled_dp.sharding_filter()
        mapped_dp = sharded_dp.map(lambda x: x * 10)
        graph = traverse_dps(mapped_dp)
        expected: Dict[Any, Any] = {
            id(mapped_dp): (
                mapped_dp,
                {
                    id(sharded_dp): (
                        sharded_dp,
                        {
                            id(shuffled_dp): (
                                shuffled_dp,
                                {id(numbers_dp): (numbers_dp, {})},
                            )
                        },
                    )
                },
            )
        }
        self.assertEqual(expected, graph)

        dps = torch.utils.data.graph_settings.get_all_graph_pipes(graph)
        self.assertEqual(len(dps), 4)
        for datapipe in (numbers_dp, shuffled_dp, sharded_dp, mapped_dp):
            self.assertIn(datapipe, dps)

    def test_traverse_forked(self):
        numbers_dp = NumbersDataset(size=50)
        dp0, dp1, dp2 = numbers_dp.fork(num_instances=3)
        dp0_upd = dp0.map(lambda x: x * 10)
        dp1_upd = dp1.filter(lambda x: x % 3 == 1)
        combined_dp = dp0_upd.mux(dp1_upd, dp2)
        graph = traverse_dps(combined_dp)
        # Each fork child points at the shared fork DataPipe (``main_datapipe``),
        # which in turn points at ``numbers_dp``.
        expected: Dict[Any, Any] = {
            id(combined_dp): (
                combined_dp,
                {
                    id(dp0_upd): (
                        dp0_upd,
                        {
                            id(dp0): (
                                dp0,
                                {
                                    id(dp0.main_datapipe): (
                                        dp0.main_datapipe,
                                        {
                                            id(dp0.main_datapipe.main_datapipe): (
                                                dp0.main_datapipe.main_datapipe,
                                                {},
                                            )
                                        },
                                    )
                                },
                            )
                        },
                    ),
                    id(dp1_upd): (
                        dp1_upd,
                        {
                            id(dp1): (
                                dp1,
                                {
                                    id(dp1.main_datapipe): (
                                        dp1.main_datapipe,
                                        {
                                            id(dp1.main_datapipe.main_datapipe): (
                                                dp1.main_datapipe.main_datapipe,
                                                {},
                                            )
                                        },
                                    )
                                },
                            )
                        },
                    ),
                    id(dp2): (
                        dp2,
                        {
                            id(dp2.main_datapipe): (
                                dp2.main_datapipe,
                                {
                                    id(dp2.main_datapipe.main_datapipe): (
                                        dp2.main_datapipe.main_datapipe,
                                        {},
                                    )
                                },
                            )
                        },
                    ),
                },
            )
        }
        self.assertEqual(expected, graph)

        dps = torch.utils.data.graph_settings.get_all_graph_pipes(graph)
        self.assertEqual(len(dps), 8)
        for _dp in [
            numbers_dp,
            dp0.main_datapipe,
            dp0,
            dp1,
            dp2,
            dp0_upd,
            dp1_upd,
            combined_dp,
        ]:
            self.assertIn(_dp, dps)

    def test_traverse_mapdatapipe(self):
        source_dp = dp.map.SequenceWrapper(range(10))
        map_dp = source_dp.map(partial(_fake_add, 1))
        graph = traverse_dps(map_dp)
        expected: Dict[Any, Any] = {
            id(map_dp): (map_dp, {id(source_dp): (source_dp, {})})
        }
        self.assertEqual(expected, graph)

    def test_traverse_mixdatapipe(self):
        source_map_dp = dp.map.SequenceWrapper(range(10))
        iter_dp = dp.iter.IterableWrapper(source_map_dp)
        graph = traverse_dps(iter_dp)
        expected: Dict[Any, Any] = {
            id(iter_dp): (iter_dp, {id(source_map_dp): (source_map_dp, {})})
        }
        self.assertEqual(expected, graph)

    def test_traverse_circular_datapipe(self):
        source_iter_dp = dp.iter.IterableWrapper(list(range(10)))
        circular_dp = TestGraph.CustomIterDataPipe(source_iter_dp)
        graph = traverse_dps(circular_dp)
        # See issue: https://github.com/pytorch/data/issues/535
        expected: Dict[Any, Any] = {
            id(circular_dp): (
                circular_dp,
                {
                    id(circular_dp._dp): (
                        circular_dp._dp,
                        {id(source_iter_dp): (source_iter_dp, {})},
                    )
                },
            )
        }
        self.assertEqual(expected, graph)

        dps = torch.utils.data.graph_settings.get_all_graph_pipes(graph)
        self.assertEqual(len(dps), 3)
        for _dp in [circular_dp, circular_dp._dp, source_iter_dp]:
            self.assertIn(_dp, dps)

    def test_traverse_unhashable_datapipe(self):
        source_iter_dp = dp.iter.IterableWrapper(list(range(10)))
        unhashable_dp = TestGraph.CustomIterDataPipe(source_iter_dp)
        graph = traverse_dps(unhashable_dp)
        with self.assertRaises(NotImplementedError):
            hash(unhashable_dp)
        expected: Dict[Any, Any] = {
            id(unhashable_dp): (
                unhashable_dp,
                {
                    id(unhashable_dp._dp): (
                        unhashable_dp._dp,
                        {id(source_iter_dp): (source_iter_dp, {})},
                    )
                },
            )
        }
        self.assertEqual(expected, graph)


def unbatch(x):
    """Return the first element of ``x``.

    Used in this file as the ``collate_fn`` of ``batch_size=1`` DataLoaders,
    turning each one-sample batch back into the bare sample.
    """
    return operator.getitem(x, 0)


class TestSerialization(TestCase):
    r"""Lambda-based DataPipes consumed by ``"spawn"``-context DataLoader workers.

    Both pipelines contain lambdas, and both tests are gated on dill via
    ``skipIfNoDill``.
    """

    def _collect_sorted(self, datapipe):
        # Two spawned workers with batch_size=1; ``unbatch`` unwraps each
        # one-element batch, and the output is sorted to ignore shuffling/ordering.
        loader = DataLoader(
            datapipe,
            num_workers=2,
            shuffle=True,
            multiprocessing_context="spawn",
            collate_fn=unbatch,
            batch_size=1,
        )
        return sorted(loader)

    @skipIfNoDill
    def test_spawn_lambdas_iter(self):
        pipeline = dp.iter.IterableWrapper(range(3)).map(lambda x: x + 1).shuffle()
        # Every value is expected twice -- presumably once per worker, since
        # nothing in the pipeline shards the data between workers.
        self.assertEqual([1, 1, 2, 2, 3, 3], self._collect_sorted(pipeline))

    @skipIfNoDill
    def test_spawn_lambdas_map(self):
        pipeline = dp.map.SequenceWrapper(range(3)).map(lambda x: x + 1).shuffle()
        self.assertEqual([1, 1, 2, 2, 3, 3], self._collect_sorted(pipeline))


class TestCircularSerialization(TestCase):
    r"""Serialization and graph traversal of DataPipes whose graph refers back to themselves.

    ``CustomIterDataPipe`` builds its inner pipeline from its own static, class and
    bound methods; the bound ``self.add_v`` makes the inner ``map`` reference the
    owning DataPipe, i.e. a reference cycle.
    """

    class CustomIterDataPipe(IterDataPipe):
        @staticmethod
        def add_one(x):
            return x + 1

        @classmethod
        def classify(cls, x):
            return 0

        def add_v(self, x):
            return x + self.v

        def __init__(self, fn, source_dp=None):
            self.fn = fn
            # Compare against None explicitly: truth-testing a DataPipe falls back
            # to ``__len__`` when its class defines one, so an empty source would
            # be mistaken for "no source" (and an unsized one could raise).
            self.source_dp = (
                source_dp
                if source_dp is not None
                else dp.iter.IterableWrapper([1, 2, 4])
            )
            self._dp = (
                self.source_dp.map(self.add_one)
                .map(self.add_v)
                .demux(2, self.classify)[0]
            )
            self.v = 1

        def __iter__(self):
            yield from self._dp

    @staticmethod
    def _unwrap_chain(datapipe):
        """Return ``(child, demux, outer_map, inner_map)`` walked down from ``datapipe._dp``."""
        child = datapipe._dp
        demux = child.main_datapipe
        outer_map = demux.main_datapipe
        inner_map = outer_map.datapipe
        return child, demux, outer_map, inner_map

    @staticmethod
    def _expected_graph(datapipe, chain, source, source_children):
        """Build the ``traverse_dps`` result expected for ``datapipe``.

        ``chain`` is the tuple returned by ``_unwrap_chain``. ``source`` (with its own
        subgraph ``source_children``) is expected twice: as a direct child of
        ``datapipe`` and as the only child of the innermost map in ``chain``.
        """
        child, demux, outer_map, inner_map = chain
        source_node = {id(source): (source, source_children)}
        return {
            id(datapipe): (
                datapipe,
                {
                    **source_node,
                    id(child): (
                        child,
                        {
                            id(demux): (
                                demux,
                                {
                                    id(outer_map): (
                                        outer_map,
                                        {id(inner_map): (inner_map, source_node)},
                                    )
                                },
                            )
                        },
                    ),
                },
            )
        }

    def test_circular_serialization_with_pickle(self):
        # Test for circular reference issue with pickle
        dp1 = TestCircularSerialization.CustomIterDataPipe(fn=_fake_fn)
        self.assertEqual(list(dp1), list(pickle.loads(pickle.dumps(dp1))))

        chain_1 = self._unwrap_chain(dp1)
        src_1 = chain_1[-1].datapipe

        res1 = traverse_dps(dp1)
        exp_res_1 = self._expected_graph(dp1, chain_1, src_1, {})
        self.assertEqual(res1, exp_res_1)

        dp2 = TestCircularSerialization.CustomIterDataPipe(fn=_fake_fn, source_dp=dp1)
        self.assertEqual(list(dp2), list(pickle.loads(pickle.dumps(dp2))))

        chain_2 = self._unwrap_chain(dp2)

        res2 = traverse_dps(dp2)
        # dp1 is dp2's source, so dp1's full subgraph appears wherever dp1 does.
        exp_res_2 = self._expected_graph(dp2, chain_2, dp1, exp_res_1[id(dp1)][1])
        self.assertEqual(res2, exp_res_2)

    class LambdaIterDataPipe(CustomIterDataPipe):
        # Same as CustomIterDataPipe, but holds lambdas (in an attribute and in a
        # list) and inserts an extra ``map(self.lambda_fn)`` into the chain.
        def __init__(self, fn, source_dp=None):
            super().__init__(fn, source_dp)
            self.container = [
                lambda x: x + 1,
            ]
            self.lambda_fn = lambda x: x + 1
            self._dp = (
                self.source_dp.map(self.add_one)
                .map(self.lambda_fn)
                .map(self.add_v)
                .demux(2, self.classify)[0]
            )

    @skipIfNoDill
    @skipIf(True, "Dill Tests")
    def test_circular_serialization_with_dill(self):
        # Test for circular reference issue with dill
        # NOTE(review): LambdaIterDataPipe's chain has one more ``map`` than
        # CustomIterDataPipe's, so ``chain_1[-1].datapipe`` below is the add_one
        # map rather than ``source_dp``. The expected graphs therefore look like
        # they would not match traverse_dps; this test is skipped unconditionally.
        # TODO confirm before re-enabling.
        dp1 = TestCircularSerialization.LambdaIterDataPipe(lambda x: x + 1)
        self.assertEqual(list(dp1), list(dill.loads(dill.dumps(dp1))))

        chain_1 = self._unwrap_chain(dp1)
        src_1 = chain_1[-1].datapipe

        res1 = traverse_dps(dp1)
        exp_res_1 = self._expected_graph(dp1, chain_1, src_1, {})
        self.assertEqual(res1, exp_res_1)

        dp2 = TestCircularSerialization.LambdaIterDataPipe(fn=_fake_fn, source_dp=dp1)
        self.assertEqual(list(dp2), list(dill.loads(dill.dumps(dp2))))

        chain_2 = self._unwrap_chain(dp2)

        res2 = traverse_dps(dp2)
        exp_res_2 = self._expected_graph(dp2, chain_2, dp1, exp_res_1[id(dp1)][1])
        self.assertEqual(res2, exp_res_2)


class CustomShardingIterDataPipe(IterDataPipe):
    r"""
    Minimal DataPipe that shards itself through its own ``apply_sharding`` hook
    instead of relying on a ``sharding_filter``.

    Until ``apply_sharding`` is called it is configured as a single instance
    (``num_of_instances=1``, ``instance_id=0``) and passes every element through.
    Afterwards it yields only the elements of ``dp`` whose position ``i``
    satisfies ``i % num_of_instances == instance_id``.
    """

    def __init__(self, dp):
        self.dp = dp
        # Start out unsharded: one instance, and we are it.
        self.num_of_instances, self.instance_id = 1, 0

    def apply_sharding(self, num_of_instances, instance_id):
        r"""Record the sharding configuration used by subsequent iterations."""
        self.num_of_instances, self.instance_id = num_of_instances, instance_id

    def __iter__(self):
        for position, element in enumerate(self.dp):
            if position % self.num_of_instances != self.instance_id:
                # Belongs to another shard.
                continue
            yield element


class TestSharding(TestCase):
    r"""
    Tests for sharding DataPipe graphs: ``sharding_filter`` and sharding groups
    (``SHARDING_PRIORITIES``), DataPipes that implement their own
    ``apply_sharding``, ``len()`` of sharded DataPipes, and use with a
    multi-worker ``DataLoader``.
    """

    def _get_pipeline(self):
        r"""
        Build ``NumbersDataset(size=10)`` forked into two branches: one mapped
        with ``_mul_10``, the other filtered with ``_mod_3_test``. The branches
        are interleaved with ``mux``.

        Uses module-level functions where ``_get_dill_pipeline`` uses lambdas;
        presumably they compute the same thing (TODO confirm).
        """
        numbers_dp = NumbersDataset(size=10)
        dp0, dp1 = numbers_dp.fork(num_instances=2)
        dp0_upd = dp0.map(_mul_10)
        dp1_upd = dp1.filter(_mod_3_test)
        combined_dp = dp0_upd.mux(dp1_upd)
        return combined_dp

    def _get_dill_pipeline(self):
        r"""
        Same structure as ``_get_pipeline``, but built with lambdas. The
        standard ``pickle`` module cannot serialize lambdas (hence the ``dill``
        in the name).
        """
        numbers_dp = NumbersDataset(size=10)
        dp0, dp1 = numbers_dp.fork(num_instances=2)
        dp0_upd = dp0.map(lambda x: x * 10)
        dp1_upd = dp1.filter(lambda x: x % 3 == 1)
        combined_dp = dp0_upd.mux(dp1_upd)
        return combined_dp

    def test_simple_sharding(self):
        # Shard 1 of 3 keeps the elements at positions 1, 4, ... of the
        # unsharded output (`all_items` below).
        sharded_dp = self._get_pipeline().sharding_filter()
        torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, 1)
        items = list(sharded_dp)
        self.assertEqual([1, 20], items)

        # Together, all three shards must yield exactly `all_items`. The lists
        # are compared sorted because concatenating the shards changes the order.
        all_items = [0, 1, 10, 4, 20, 7]
        items = []
        for i in range(3):
            sharded_dp = self._get_pipeline().sharding_filter()
            torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, i)
            items += list(sharded_dp)
        self.assertEqual(sorted(all_items), sorted(items))

    def test_sharding_groups(self):
        def construct_sharded_pipe():
            # Chain three sharding_filters, each bound to a different sharding
            # group. Return the tail of the chain and every filter in it.
            sharding_pipes = []
            dp = NumbersDataset(size=90)
            dp = dp.sharding_filter(
                sharding_group_filter=SHARDING_PRIORITIES.DISTRIBUTED
            )
            sharding_pipes.append(dp)
            dp = dp.sharding_filter(
                sharding_group_filter=SHARDING_PRIORITIES.MULTIPROCESSING
            )
            sharding_pipes.append(dp)
            dp = dp.sharding_filter(sharding_group_filter=300)
            sharding_pipes.append(dp)
            return dp, sharding_pipes

        dp, sharding_pipes = construct_sharded_pipe()

        # Apply the settings for every group to every filter. As the expected
        # output implies, each filter honors only its own group, so the result
        # is three successive sharding steps.
        for pipe in sharding_pipes:
            pipe.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DISTRIBUTED)
            pipe.apply_sharding(
                5, 3, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING
            )
            pipe.apply_sharding(3, 1, sharding_group=300)

        actual = list(dp)
        expected = [17, 47, 77]
        self.assertEqual(expected, actual)
        self.assertEqual(3, len(dp))

        # Mixing DEFAULT with a specific group on the same pipe must raise,
        # in either order.
        dp, _ = construct_sharded_pipe()
        dp.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DEFAULT)
        with self.assertRaises(Exception):
            dp.apply_sharding(5, 3, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING)

        dp, _ = construct_sharded_pipe()
        dp.apply_sharding(5, 3, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING)
        with self.assertRaises(Exception):
            dp.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DEFAULT)

    # Test tud.datapipes.iter.grouping.SHARDING_PRIORITIES for backward compatibility
    # TODO: Remove this test once tud.datapipes.iter.grouping.SHARDING_PRIORITIES is deprecated
    def test_sharding_groups_in_legacy_grouping_package(self):
        # Importing from the legacy location must emit a FutureWarning.
        with self.assertWarnsRegex(
            FutureWarning,
            r"Please use `SHARDING_PRIORITIES` "
            "from the `torch.utils.data.datapipes.iter.sharding`",
        ):
            from torch.utils.data.datapipes.iter.grouping import (
                SHARDING_PRIORITIES as LEGACY_SHARDING_PRIORITIES,
            )

        # This is the same scenario as test_sharding_groups, using the legacy enum.
        def construct_sharded_pipe():
            sharding_pipes = []
            dp = NumbersDataset(size=90)
            dp = dp.sharding_filter(
                sharding_group_filter=LEGACY_SHARDING_PRIORITIES.DISTRIBUTED
            )
            sharding_pipes.append(dp)
            dp = dp.sharding_filter(
                sharding_group_filter=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING
            )
            sharding_pipes.append(dp)
            dp = dp.sharding_filter(sharding_group_filter=300)
            sharding_pipes.append(dp)
            return dp, sharding_pipes

        dp, sharding_pipes = construct_sharded_pipe()

        for pipe in sharding_pipes:
            pipe.apply_sharding(
                2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DISTRIBUTED
            )
            pipe.apply_sharding(
                5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING
            )
            pipe.apply_sharding(3, 1, sharding_group=300)

        actual = list(dp)
        expected = [17, 47, 77]
        self.assertEqual(expected, actual)
        self.assertEqual(3, len(dp))

        dp, _ = construct_sharded_pipe()
        dp.apply_sharding(2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DEFAULT)
        with self.assertRaises(Exception):
            dp.apply_sharding(
                5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING
            )

        dp, _ = construct_sharded_pipe()
        dp.apply_sharding(
            5, 3, sharding_group=LEGACY_SHARDING_PRIORITIES.MULTIPROCESSING
        )
        with self.assertRaises(Exception):
            dp.apply_sharding(2, 1, sharding_group=LEGACY_SHARDING_PRIORITIES.DEFAULT)

    def test_legacy_custom_sharding(self):
        # graph_settings.apply_sharding also reaches a DataPipe that
        # implements its own `apply_sharding` (no sharding_filter involved).
        dp = self._get_pipeline()
        sharded_dp = CustomShardingIterDataPipe(dp)
        torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, 1)
        items = list(sharded_dp)
        self.assertEqual([1, 20], items)

    def test_sharding_length(self):
        # 13 items over 3 shards: the first shard gets the extra item (5/4/4).
        numbers_dp = dp.iter.IterableWrapper(range(13))
        sharded_dp0 = numbers_dp.sharding_filter()
        torch.utils.data.graph_settings.apply_sharding(sharded_dp0, 3, 0)
        sharded_dp1 = numbers_dp.sharding_filter()
        torch.utils.data.graph_settings.apply_sharding(sharded_dp1, 3, 1)
        sharded_dp2 = numbers_dp.sharding_filter()
        torch.utils.data.graph_settings.apply_sharding(sharded_dp2, 3, 2)
        self.assertEqual(13, len(numbers_dp))
        self.assertEqual(5, len(sharded_dp0))
        self.assertEqual(4, len(sharded_dp1))
        self.assertEqual(4, len(sharded_dp2))

        # There are fewer items than shards, so the surplus shard is empty.
        numbers_dp = dp.iter.IterableWrapper(range(1))
        sharded_dp0 = numbers_dp.sharding_filter()
        torch.utils.data.graph_settings.apply_sharding(sharded_dp0, 2, 0)
        sharded_dp1 = numbers_dp.sharding_filter()
        torch.utils.data.graph_settings.apply_sharding(sharded_dp1, 2, 1)
        self.assertEqual(1, len(sharded_dp0))
        self.assertEqual(0, len(sharded_dp1))

    def test_old_dataloader(self):
        # With 2 workers, the combined output must equal the unsharded output
        # (no duplicates), compared sorted since worker order may differ.
        dp0 = self._get_pipeline()
        expected = list(dp0)

        dp0 = self._get_pipeline().sharding_filter()
        dl = DataLoader(dp0, batch_size=1, shuffle=False, num_workers=2)
        items = list(dl)

        self.assertEqual(sorted(expected), sorted(items))

    def test_legacy_custom_sharding_with_old_dataloader(self):
        # Same check as test_old_dataloader, but sharding is done by a custom
        # `apply_sharding` DataPipe instead of sharding_filter.
        dp0 = self._get_pipeline()
        expected = list(dp0)

        dp0 = self._get_pipeline()
        dp0 = CustomShardingIterDataPipe(dp0)
        dl = DataLoader(dp0, batch_size=1, shuffle=False, num_workers=2)
        items = list(dl)

        self.assertEqual(sorted(expected), sorted(items))

    def test_multi_sharding(self):
        # Raises Error when multiple sharding on the single branch
        numbers_dp = dp.iter.IterableWrapper(range(13))
        sharded_dp = numbers_dp.sharding_filter()
        sharded_dp = sharded_dp.sharding_filter()
        with self.assertRaisesRegex(
            RuntimeError, "Sharding twice on a single pipeline"
        ):
            torch.utils.data.graph_settings.apply_sharding(sharded_dp, 3, 0)

        # Raises Error when sharding on both data source and branch
        numbers_dp = dp.iter.IterableWrapper(range(13)).sharding_filter()
        dp1, dp2 = numbers_dp.fork(2)
        sharded_dp = dp1.sharding_filter()
        zip_dp = dp2.zip(sharded_dp)
        with self.assertRaisesRegex(
            RuntimeError, "Sharding twice on a single pipeline"
        ):
            torch.utils.data.graph_settings.apply_sharding(zip_dp, 3, 0)

        # Raises Error when multiple sharding on the branch and end
        numbers_dp = dp.iter.IterableWrapper(range(13))
        dp1, dp2 = numbers_dp.fork(2)
        sharded_dp = dp1.sharding_filter()
        zip_dp = dp2.zip(sharded_dp).sharding_filter()
        with self.assertRaisesRegex(
            RuntimeError, "Sharding twice on a single pipeline"
        ):
            torch.utils.data.graph_settings.apply_sharding(zip_dp, 3, 0)

        # Single sharding_filter on data source
        numbers_dp = dp.iter.IterableWrapper(range(13)).sharding_filter()
        dp1, dp2 = numbers_dp.fork(2)
        zip_dp = dp1.zip(dp2)
        torch.utils.data.graph_settings.apply_sharding(zip_dp, 3, 0)
        self.assertEqual(list(zip_dp), [(i * 3, i * 3) for i in range(13 // 3 + 1)])

        # Single sharding_filter per branch
        numbers_dp = dp.iter.IterableWrapper(range(13))
        dp1, dp2 = numbers_dp.fork(2)
        sharded_dp1 = dp1.sharding_filter()
        sharded_dp2 = dp2.sharding_filter()
        zip_dp = sharded_dp1.zip(sharded_dp2)
        torch.utils.data.graph_settings.apply_sharding(zip_dp, 3, 0)
        self.assertEqual(list(zip_dp), [(i * 3, i * 3) for i in range(13 // 3 + 1)])


class TestIterDataPipeSingletonConstraint(TestCase):
    r"""
    Each `IterDataPipe` can only have one active iterator. Whenever a new iterator is created, older
    iterators are invalidated. These tests aim to ensure `IterDataPipe` follows this behavior.
    """

    def _check_single_iterator_invalidation_logic(self, source_dp: IterDataPipe):
        r"""
        Given an IterDataPipe, verifies that the iterator can be read, reset, and the creation of
        a second iterator invalidates the first one.

        ``source_dp`` must yield exactly ``0, 1, ..., 9``.
        """
        it1 = iter(source_dp)
        self.assertEqual(list(range(10)), list(it1))
        it1 = iter(source_dp)
        self.assertEqual(
            list(range(10)), list(it1)
        )  # A fresh iterator can be read in full again
        it1 = iter(source_dp)
        self.assertEqual(0, next(it1))
        it2 = iter(source_dp)  # This should invalidate `it1`
        self.assertEqual(0, next(it2))  # Should read from the beginning again
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            next(it1)

    def test_iterdatapipe_singleton_generator(self):
        r"""
        Testing for the case where IterDataPipe's `__iter__` is a generator function.
        """

        # Functional Test: Check if invalidation logic is correct
        source_dp: IterDataPipe = dp.iter.IterableWrapper(range(10))
        self._check_single_iterator_invalidation_logic(source_dp)

        # Functional Test: extend the test to a pipeline
        # (presumably `_fake_fn`/`_fake_filter_fn` leave 0..9 unchanged, since
        # the helper requires that output — TODO confirm)
        dps = source_dp.map(_fake_fn).filter(_fake_filter_fn)
        self._check_single_iterator_invalidation_logic(dps)

        # Functional Test: multiple simultaneous references to the same DataPipe fails
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            for _ in zip(source_dp, source_dp):
                pass

        # Functional Test: sequential references work
        for _ in zip(list(source_dp), list(source_dp)):
            pass

    def test_iterdatapipe_singleton_self_next(self):
        r"""
        Testing for the case where IterDataPipe's `__iter__` returns `self` and there is a `__next__` method
        Note that the following DataPipe is a singleton by default (because `__iter__` returns `self`).
        """

        class _CustomIterDP_Self(IterDataPipe):
            def __init__(self, iterable):
                self.source = iterable
                self.iterable = iter(iterable)

            def __iter__(self):
                self.reset()
                return self

            def __next__(self):
                return next(self.iterable)

            def reset(self):
                self.iterable = iter(self.source)

        # Functional Test: Check that every `__iter__` call returns the same object
        source_dp = _CustomIterDP_Self(range(10))
        res = list(source_dp)
        it = iter(source_dp)
        self.assertEqual(res, list(it))

        # Functional Test: Check if invalidation logic is correct
        source_dp = _CustomIterDP_Self(range(10))
        self._check_single_iterator_invalidation_logic(source_dp)
        self.assertEqual(
            1, next(source_dp)
        )  # `source_dp` is still valid and can be read

        # Functional Test: extend the test to a pipeline
        source_dp = _CustomIterDP_Self(
            dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn)
        )
        self._check_single_iterator_invalidation_logic(source_dp)
        self.assertEqual(
            1, next(source_dp)
        )  # `source_dp` is still valid and can be read

        # Functional Test: multiple simultaneous references to the same DataPipe fails
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            for _ in zip(source_dp, source_dp):
                pass

    def test_iterdatapipe_singleton_new_object(self):
        r"""
        Testing for the case where IterDataPipe's `__iter__` isn't a generator nor returns `self`,
        and there isn't a `__next__` method.
        """

        class _CustomIterDP(IterDataPipe):
            def __init__(self, iterable):
                self.iterable = iter(iterable)

            def __iter__(self):  # Note that this doesn't reset
                return self.iterable  # Intentionally not returning `self`

        # Functional Test: Check if invalidation logic is correct
        # (`it2` continues at 1 because `__iter__` hands out the same underlying
        # iterator without resetting it)
        source_dp = _CustomIterDP(range(10))
        it1 = iter(source_dp)
        self.assertEqual(0, next(it1))
        it2 = iter(source_dp)
        self.assertEqual(1, next(it2))
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            next(it1)

        # Functional Test: extend the test to a pipeline
        source_dp = _CustomIterDP(
            dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn)
        )
        it1 = iter(source_dp)
        self.assertEqual(0, next(it1))
        it2 = iter(source_dp)
        self.assertEqual(1, next(it2))
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            next(it1)

        # Functional Test: multiple simultaneous references to the same DataPipe fails
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            for _ in zip(source_dp, source_dp):
                pass

    def test_iterdatapipe_singleton_buggy(self):
        r"""
        Buggy test case where IterDataPipe's `__iter__` returns a new object, but also has
        a `__next__` method.
        """

        class _CustomIterDP(IterDataPipe):
            def __init__(self, iterable):
                self.source = iterable
                self.iterable = iter(iterable)

            def __iter__(self):
                return iter(self.source)  # Intentionally not returning `self`

            def __next__(self):
                return next(self.iterable)

        # Functional Test: Check if invalidation logic is correct
        source_dp = _CustomIterDP(range(10))
        self._check_single_iterator_invalidation_logic(source_dp)
        self.assertEqual(0, next(source_dp))  # `__next__` is unrelated with `__iter__`

        # Functional Test: Special case to show `__next__` is unrelated with `__iter__`
        source_dp = _CustomIterDP(range(10))
        self.assertEqual(0, next(source_dp))
        it1 = iter(source_dp)
        self.assertEqual(0, next(it1))
        self.assertEqual(1, next(source_dp))
        it2 = iter(source_dp)  # invalidates `it1`
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            next(it1)
        self.assertEqual(2, next(source_dp))  # not impacted by the creation of `it2`
        self.assertEqual(
            list(range(10)), list(it2)
        )  # `it2` still works because it is a new object

    def test_iterdatapipe_singleton_constraint_multiple_outputs(self):
        r"""
        Testing for the case where IterDataPipe has multiple child DataPipes as outputs.
        """
        # Functional Test: all previous related iterators should be invalidated when a new iterator
        #                  is created from a ChildDataPipe
        source_dp: IterDataPipe = dp.iter.IterableWrapper(range(10))
        cdp1, cdp2 = source_dp.fork(num_instances=2)
        it1, it2 = iter(cdp1), iter(cdp2)
        self.assertEqual(list(range(10)), list(it1))
        self.assertEqual(list(range(10)), list(it2))
        it1, it2 = iter(cdp1), iter(cdp2)
        with warnings.catch_warnings(record=True) as wa:
            it3 = iter(cdp1)  # This should invalidate `it1` and `it2`
            self.assertEqual(len(wa), 1)
            self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted")
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            next(it1)
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            next(it2)
        self.assertEqual(0, next(it3))
        # The next line should not invalidate anything, as there was no new iterator created
        # for `cdp2` after `it2` was invalidated
        it4 = iter(cdp2)
        self.assertEqual(1, next(it3))  # An error shouldn't be raised here
        self.assertEqual(list(range(10)), list(it4))

        # Functional Test: invalidation when a new iterator is created from `source_dp`
        source_dp = dp.iter.IterableWrapper(range(10))
        cdp1, cdp2 = source_dp.fork(num_instances=2)
        it1, it2 = iter(cdp1), iter(cdp2)
        self.assertEqual(list(range(10)), list(it1))
        self.assertEqual(list(range(10)), list(it2))
        it1, it2 = iter(cdp1), iter(cdp2)
        self.assertEqual(0, next(it1))
        self.assertEqual(0, next(it2))
        it3 = iter(source_dp)  # note that a new iterator is created from `source_dp`
        self.assertEqual(
            0, next(it3)
        )  # `it3` should invalidate `it1` and `it2` since they both use `source_dp`
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            next(it1)
        self.assertEqual(1, next(it3))

        # Functional Test: Extending test to pipeline
        source_dp = (
            dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn)
        )
        cdp1, cdp2 = source_dp.fork(num_instances=2)
        it1, it2 = iter(cdp1), iter(cdp2)
        self.assertEqual(list(range(10)), list(it1))
        self.assertEqual(list(range(10)), list(it2))
        it1, it2 = iter(cdp1), iter(cdp2)
        with warnings.catch_warnings(record=True) as wa:
            it3 = iter(cdp1)  # This should invalidate `it1` and `it2`
            self.assertEqual(len(wa), 1)
            self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted")
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            next(it1)
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            next(it2)
        with warnings.catch_warnings(record=True) as wa:
            it1, it2 = iter(cdp1), iter(cdp2)
            self.assertEqual(len(wa), 1)
            self.assertRegex(str(wa[0].message), r"child DataPipes are not exhausted")
        self.assertEqual(0, next(it1))
        self.assertEqual(0, next(it2))
        it3 = iter(source_dp)  # note that a new iterator is created from `source_dp`
        self.assertEqual(
            0, next(it3)
        )  # `it3` should invalidate `it1` and `it2` since they both use `source_dp`
        with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
            next(it1)
        self.assertEqual(1, next(it3))


class TestIterDataPipeCountSampleYielded(TestCase):
    r"""
    Tests that ``IterDataPipe._number_of_samples_yielded`` tracks how many
    samples were produced. This covers full reads, partial reads, reset via
    ``iter()``, and reads that end in an exception, across different
    ``__iter__`` styles.
    """

    def _yield_count_test_helper(self, datapipe, n_expected_samples):
        r"""
        Check ``datapipe._number_of_samples_yielded`` after a full read, after
        a partial read of ``n_expected_samples`` items, and after a full read
        from a freshly created iterator.
        """
        # Functional Test: Check if number of samples yielded is as expected
        res = list(datapipe)
        self.assertEqual(len(res), datapipe._number_of_samples_yielded)

        # Functional Test: Check if the count is correct when DataPipe is partially read
        it = iter(datapipe)
        res = []
        for i, value in enumerate(it):
            res.append(value)
            # Stop right after the `n_expected_samples`-th value.
            if i == n_expected_samples - 1:
                break
        self.assertEqual(n_expected_samples, datapipe._number_of_samples_yielded)

        # Functional Test: Check for reset behavior and if iterator also works
        it = iter(datapipe)  # reset the DataPipe
        res = list(it)
        self.assertEqual(len(res), datapipe._number_of_samples_yielded)

    def test_iterdatapipe_sample_yielded_generator_function(self):
        # Functional Test: `__iter__` is a generator function
        datapipe: IterDataPipe = dp.iter.IterableWrapper(range(10))
        self._yield_count_test_helper(datapipe, n_expected_samples=5)

    def test_iterdatapipe_sample_yielded_generator_function_exception(self):
        # Functional Test: `__iter__` is a custom generator function with exception
        class _CustomGeneratorFnDataPipe(IterDataPipe):
            # This class's `__iter__` has a Runtime Error
            def __iter__(self):
                yield 0
                yield 1
                yield 2
                raise RuntimeError("Custom test error after yielding 3 elements")
                # Unreachable: the raise above always fires after 3 elements.
                yield 3

        # Functional Test: Ensure the count is correct even when exception is raised
        datapipe: IterDataPipe = _CustomGeneratorFnDataPipe()
        with self.assertRaisesRegex(
            RuntimeError, "Custom test error after yielding 3 elements"
        ):
            list(datapipe)
        self.assertEqual(3, datapipe._number_of_samples_yielded)

        # Functional Test: Check for reset behavior and if iterator also works
        it = iter(datapipe)  # reset the DataPipe
        with self.assertRaisesRegex(
            RuntimeError, "Custom test error after yielding 3 elements"
        ):
            list(it)
        self.assertEqual(3, datapipe._number_of_samples_yielded)

    def test_iterdatapipe_sample_yielded_return_self(self):
        class _CustomGeneratorDataPipe(IterDataPipe):
            # This class's `__iter__` is not a generator function
            # (it returns a stored iterator, not `self`, despite the test name)
            def __init__(self) -> None:
                self.source = iter(range(10))

            def __iter__(self):
                return self.source

            def reset(self):
                self.source = iter(range(10))

        datapipe: IterDataPipe = _CustomGeneratorDataPipe()
        self._yield_count_test_helper(datapipe, n_expected_samples=5)

    def test_iterdatapipe_sample_yielded_next(self):
        class _CustomNextDataPipe(IterDataPipe):
            # This class's `__iter__` returns `self` and has a `__next__`
            def __init__(self) -> None:
                self.source = iter(range(10))

            def __iter__(self):
                return self

            def __next__(self):
                return next(self.source)

            def reset(self):
                self.source = iter(range(10))

        datapipe: IterDataPipe = _CustomNextDataPipe()
        self._yield_count_test_helper(datapipe, n_expected_samples=5)

    def test_iterdatapipe_sample_yielded_next_exception(self):
        class _CustomNextDataPipe(IterDataPipe):
            # This class's `__iter__` returns `self` and has a `__next__`
            # that raises on the 4th call (after 3 successful values)
            def __init__(self) -> None:
                self.source = iter(range(10))
                self.count = 0

            def __iter__(self):
                return self

            def __next__(self):
                if self.count == 3:
                    raise RuntimeError("Custom test error after yielding 3 elements")
                self.count += 1
                return next(self.source)

            def reset(self):
                self.count = 0
                self.source = iter(range(10))

        # Functional Test: Ensure the count is correct even when exception is raised
        datapipe: IterDataPipe = _CustomNextDataPipe()
        with self.assertRaisesRegex(
            RuntimeError, "Custom test error after yielding 3 elements"
        ):
            list(datapipe)
        self.assertEqual(3, datapipe._number_of_samples_yielded)

        # Functional Test: Check for reset behavior and if iterator also works
        it = iter(datapipe)  # reset the DataPipe
        with self.assertRaisesRegex(
            RuntimeError, "Custom test error after yielding 3 elements"
        ):
            list(it)
        self.assertEqual(3, datapipe._number_of_samples_yielded)


class _CustomNonGeneratorTestDataPipe(IterDataPipe):
    def __init__(self) -> None:
        self.n = 10
        self.source = list(range(self.n))

    # This class's `__iter__` is not a generator function
    def __iter__(self):
        return iter(self.source)

    def __len__(self):
        return self.n


class _CustomSelfNextTestDataPipe(IterDataPipe):
    def __init__(self) -> None:
        self.n = 10
        self.iter = iter(range(self.n))

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.iter)

    def reset(self):
        self.iter = iter(range(self.n))

    def __len__(self):
        return self.n


class TestIterDataPipeGraphFastForward(TestCase):
    def _fast_forward_graph_test_helper(
        self, datapipe, fast_forward_fn, expected_res, n_iterations=3, rng=None
    ):
        if rng is None:
            rng = torch.Generator()
        rng = rng.manual_seed(0)
        torch.utils.data.graph_settings.apply_random_seed(datapipe, rng)

        # Test Case: fast forward works with list
        rng.manual_seed(0)
        fast_forward_fn(datapipe, n_iterations, rng)
        actual_res = list(datapipe)
        self.assertEqual(len(datapipe) - n_iterations, len(actual_res))
        self.assertEqual(expected_res[n_iterations:], actual_res)

        # Test Case: fast forward works with iterator
        rng.manual_seed(0)
        fast_forward_fn(datapipe, n_iterations, rng)
        it = iter(datapipe)
        actual_res = list(it)
        self.assertEqual(len(datapipe) - n_iterations, len(actual_res))
        self.assertEqual(expected_res[n_iterations:], actual_res)
        with self.assertRaises(StopIteration):
            next(it)

    def test_simple_snapshot_graph(self):
        graph1 = dp.iter.IterableWrapper(range(10))
        res1 = list(range(10))
        self._fast_forward_graph_test_helper(
            graph1, _simple_graph_snapshot_restoration, expected_res=res1
        )

        graph2 = graph1.map(_mul_10)
        res2 = [10 * x for x in res1]
        self._fast_forward_graph_test_helper(
            graph2, _simple_graph_snapshot_restoration, expected_res=res2
        )

        rng = torch.Generator()
        graph3 = graph2.shuffle()
        rng.manual_seed(0)
        torch.utils.data.graph_settings.apply_random_seed(graph3, rng)
        res3 = list(graph3)
        self._fast_forward_graph_test_helper(
            graph3, _simple_graph_snapshot_restoration, expected_res=res3
        )

        graph4 = graph3.map(_mul_10)
        res4 = [10 * x for x in res3]
        self._fast_forward_graph_test_helper(
            graph4, _simple_graph_snapshot_restoration, expected_res=res4
        )

        batch_size = 2
        graph5 = graph4.batch(batch_size)
        res5 = [
            res4[i : i + batch_size] for i in range(0, len(res4), batch_size)
        ]  # .batch(2)
        self._fast_forward_graph_test_helper(
            graph5, _simple_graph_snapshot_restoration, expected_res=res5
        )

        # With `fork` and `zip`
        cdp1, cdp2 = graph5.fork(2)
        graph6 = cdp1.zip(cdp2)
        rng = rng.manual_seed(100)
        torch.utils.data.graph_settings.apply_random_seed(graph6, rng)
        res6 = [(x, x) for x in res5]
        self._fast_forward_graph_test_helper(
            graph6, _simple_graph_snapshot_restoration, expected_res=res6
        )

        # With `fork` and `concat`
        graph7 = cdp1.concat(cdp2)
        res7 = res5 * 2
        self._fast_forward_graph_test_helper(
            graph7, _simple_graph_snapshot_restoration, expected_res=res7
        )

        # Raises an exception if the graph has already been restored
        with self.assertRaisesRegex(
            RuntimeError, "Snapshot restoration cannot be applied."
        ):
            _simple_graph_snapshot_restoration(graph7, 1)
            _simple_graph_snapshot_restoration(graph7, 1)

    def test_simple_snapshot_custom_non_generator(self):
        graph = _CustomNonGeneratorTestDataPipe()
        self._fast_forward_graph_test_helper(
            graph, _simple_graph_snapshot_restoration, expected_res=range(10)
        )

    def test_simple_snapshot_custom_self_next(self):
        graph = _CustomSelfNextTestDataPipe()
        self._fast_forward_graph_test_helper(
            graph, _simple_graph_snapshot_restoration, expected_res=range(10)
        )

    def _snapshot_test_helper(self, datapipe, expected_res, n_iter=3, rng=None):
        """
        Extend the previous test with serialization and deserialization test.
        """
        if rng is None:
            rng = torch.Generator()
        rng.manual_seed(0)
        torch.utils.data.graph_settings.apply_random_seed(datapipe, rng)
        it = iter(datapipe)
        for _ in range(n_iter):
            next(it)
        serialized_graph = pickle.dumps(datapipe)
        deserialized_graph = pickle.loads(serialized_graph)
        self.assertEqual(n_iter, datapipe._number_of_samples_yielded)
        self.assertEqual(n_iter, deserialized_graph._number_of_samples_yielded)

        rng_for_deserialized = torch.Generator()
        rng_for_deserialized.manual_seed(0)
        _simple_graph_snapshot_restoration(
            deserialized_graph, n_iter, rng=rng_for_deserialized
        )
        self.assertEqual(expected_res[n_iter:], list(it))
        self.assertEqual(expected_res[n_iter:], list(deserialized_graph))

    def test_simple_snapshot_graph_with_serialization(self):
        graph1 = dp.iter.IterableWrapper(range(10))
        res1 = list(range(10))
        self._snapshot_test_helper(graph1, expected_res=res1)

        graph2 = graph1.map(_mul_10)
        res2 = [10 * x for x in res1]
        self._snapshot_test_helper(graph2, expected_res=res2)

        rng = torch.Generator()
        graph3 = graph2.shuffle()
        rng.manual_seed(0)
        torch.utils.data.graph_settings.apply_random_seed(graph3, rng)
        res3 = list(graph3)
        self._snapshot_test_helper(graph3, expected_res=res3)

        graph4 = graph3.map(_mul_10)
        res4 = [10 * x for x in res3]
        self._snapshot_test_helper(graph4, expected_res=res4)

        batch_size = 2
        graph5 = graph4.batch(batch_size)
        res5 = [
            res4[i : i + batch_size] for i in range(0, len(res4), batch_size)
        ]  # .batch(2)
        self._snapshot_test_helper(graph5, expected_res=res5)

        # With `fork` and `zip`
        cdp1, cdp2 = graph5.fork(2)
        graph6 = cdp1.zip(cdp2)
        res6 = [(x, x) for x in res5]
        self._snapshot_test_helper(graph6, expected_res=res6)

        # With `fork` and `concat`
        graph7 = cdp1.concat(cdp2)
        res7 = res5 * 2
        self._snapshot_test_helper(graph7, expected_res=res7)

    def test_simple_snapshot_graph_repeated(self):
        cdp1, cdp2 = (
            dp.iter.IterableWrapper(range(10))
            .map(_mul_10)
            .shuffle()
            .map(_mul_10)
            .map(_mul_10)
            .fork(2)
        )
        graph = cdp1.zip(cdp2)

        rng = torch.Generator()
        rng.manual_seed(0)
        torch.utils.data.graph_settings.apply_random_seed(graph, rng)

        # Get expected result
        expected_res = list(graph)

        rng.manual_seed(0)
        torch.utils.data.graph_settings.apply_random_seed(graph, rng)
        it = iter(graph)
        n_iter = 3
        for _ in range(n_iter):
            next(it)

        # First serialization/deserialization
        serialized_graph = pickle.dumps(graph)
        deserialized_graph = pickle.loads(serialized_graph)

        rng_for_deserialized = torch.Generator()
        rng_for_deserialized.manual_seed(0)
        _simple_graph_snapshot_restoration(
            deserialized_graph,
            deserialized_graph._number_of_samples_yielded,
            rng=rng_for_deserialized,
        )

        it = iter(deserialized_graph)
        # Get the next element and ensure it is as expected
        self.assertEqual(expected_res[3], next(it))

        # Serializalize/Deserialize and fast-forward again after to ensure it works
        serialized_graph2 = pickle.dumps(deserialized_graph)
        deserialized_graph2 = pickle.loads(serialized_graph2)

        rng_for_deserialized = torch.Generator()
        rng_for_deserialized.manual_seed(0)
        _simple_graph_snapshot_restoration(
            deserialized_graph2,
            deserialized_graph._number_of_samples_yielded,
            rng=rng_for_deserialized,
        )

        # Get the next element and ensure it is as expected
        self.assertEqual(expected_res[4:], list(deserialized_graph2))


if __name__ == "__main__":
    # Entry point when this file is executed directly: hand off to the
    # `run_tests` helper imported from torch.testing._internal.common_utils.
    run_tests()
