# Owner(s): ["module: unknown"]

import glob
import io
import os
import unittest

import torch
from torch.testing._internal.common_utils import run_tests, TestCase


# Path (relative to the repository root) of the checked-in bundled-license text.
license_file = "third_party/LICENSES_BUNDLED.txt"
# Phrase that must appear in the LICENSE file of an installed torch dist-info.
starting_txt = "The PyTorch repository and source distributions bundle"

# ``third_party`` is only importable from a PyTorch source checkout; anywhere
# else the generator stays ``None`` so the source-tree test can be skipped.
create_bundled = None
try:
    from third_party.build_bundled import create_bundled
except ImportError:
    pass

# Directory holding the installed ``torch`` package, and every
# ``torch-*dist-info`` metadata directory found directly inside it.
site_packages = os.path.dirname(os.path.dirname(torch.__file__))
distinfo = glob.glob(os.path.join(site_packages, "torch-*dist-info"))


class TestLicense(TestCase):
    """Check the bundled third-party license text in the source tree and in an install."""

    @unittest.skipIf(not create_bundled, "can only be run in a source tree")
    def test_license_for_wheel(self):
        """Regenerate the bundled-license text and compare it to ``license_file``.

        ``create_bundled`` is given an in-memory buffer as its output, so the
        comparison is made against freshly generated text rather than the
        file on disk.

        Raises:
            AssertionError: if the checked-in file differs from the
                regenerated text, i.e. it needs to be rebuilt.
        """
        current = io.StringIO()
        create_bundled("third_party", current)
        with open(license_file) as fid:
            src_tree = fid.read()
        if src_tree != current.getvalue():
            raise AssertionError(
                f'the contents of "{license_file}" do not '
                "match the current state of the third_party files. Use "
                '"python third_party/build_bundled.py" to regenerate it'
            )

    @unittest.skipIf(not distinfo, "no installation in site-package to test")
    def test_distinfo_license(self):
        """If run when pytorch is installed via a wheel, the license will be in
        site-packages/torch-*dist-info/LICENSE. Make sure it contains the third
        party bundle of licenses.

        Raises:
            AssertionError: if more than one ``torch-*dist-info`` directory is
                found (the install is ambiguous), or if the LICENSE file does
                not contain ``starting_txt``.
        """
        if len(distinfo) > 1:
            raise AssertionError(
                'Found too many "torch-*dist-info" directories '
                f'in "{site_packages}", expected only one'
            )
        with open(os.path.join(distinfo[0], "LICENSE")) as fid:
            txt = fid.read()
        # assertIn reports both operands on failure, unlike assertTrue(x in y).
        self.assertIn(starting_txt, txt)


if __name__ == "__main__":
    # Allow running this file directly; hand off to the run_tests helper
    # imported from torch.testing._internal.common_utils.
    run_tests()
