From 1766473944f58921b61e4e982d96099d52c5a3c4 Mon Sep 17 00:00:00 2001
From: Xdng Yng <wyverald@gmail.com>
Date: Mon, 22 Apr 2024 13:22:11 -0700
Subject: [PATCH] Enforce file name format for MODULE.bazel includes

They must end in `.MODULE.bazel`.

Follow-up for https://github.com/bazelbuild/bazel/issues/17880

Closes #22075.

PiperOrigin-RevId: 627136756
Change-Id: If9b1797f2e7ddc1aebd929646329e832288bfd8a
---
 .../lib/bazel/bzlmod/CompiledModuleFile.java  |  2 +-
 .../lib/bazel/bzlmod/ModuleFileFunction.java  |  8 +++
 .../bazel/bzlmod/CompiledModuleFileTest.java  |  5 +-
 .../bazel/bzlmod/ModuleFileFunctionTest.java  | 50 ++++++++++++-------
 src/test/py/bazel/bzlmod/bazel_module_test.py |  4 +-
 5 files changed, 48 insertions(+), 21 deletions(-)

diff --git a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/CompiledModuleFile.java b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/CompiledModuleFile.java
index 8d5b57e383b..2cdcacf79bd 100644
--- a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/CompiledModuleFile.java
+++ b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/CompiledModuleFile.java
@@ -142,7 +142,7 @@ public record CompiledModuleFile(
       @Override
       public void visit(DotExpression node) {
         visit(node.getObject());
-        if (includeWasAssigned || !node.getField().getName().equals(INCLUDE_IDENTIFIER)) {
+        if (!node.getField().getName().equals(INCLUDE_IDENTIFIER)) {
           // This is fine: `whatever.include`
           // (so `include` can be used as a tag class name)
           visit(node.getField());
diff --git a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunction.java b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunction.java
index 2eb45924d0f..e6ac3416528 100644
--- a/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunction.java
+++ b/src/main/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunction.java
@@ -351,6 +351,14 @@ public class ModuleFileFunction implements SkyFunction {
             includeStatement.includeLabel(),
             includeStatement.location());
       }
+      if (!includeStatement.includeLabel().endsWith(".MODULE.bazel")) {
+        throw errorf(
+            Code.BAD_MODULE,
+            "bad include label '%s' at %s: the file to be included must have a name ending in"
+                + " '.MODULE.bazel'",
+            includeStatement.includeLabel(),
+            includeStatement.location());
+      }
       try {
         includeLabels.add(Label.parseCanonical(includeStatement.includeLabel()));
       } catch (LabelSyntaxException e) {
diff --git a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/CompiledModuleFileTest.java b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/CompiledModuleFileTest.java
index 6a7125b6c3a..34bfdd6f935 100644
--- a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/CompiledModuleFileTest.java
+++ b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/CompiledModuleFileTest.java
@@ -91,13 +91,16 @@ public class CompiledModuleFileTest {
   public void checkSyntax_good_includeIdentifierReassigned() throws Exception {
     String program =
         """
+        include('world')
         include = print
         # from this point on, we no longer check anything about `include` usage.
         include('hello')
         str(include)
         exclude = include
         """;
-    assertThat(checkSyntax(program)).isEmpty();
+    assertThat(checkSyntax(program))
+        .containsExactly(
+            new IncludeStatement("world", Location.fromFileLineColumn("test file", 1, 1)));
   }
 
   @Test
diff --git a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunctionTest.java b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunctionTest.java
index 76d6f4ebd9b..cea99f30491 100644
--- a/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunctionTest.java
+++ b/src/test/java/com/google/devtools/build/lib/bazel/bzlmod/ModuleFileFunctionTest.java
@@ -334,22 +334,22 @@ public class ModuleFileFunctionTest extends FoundationTestCase {
     scratch.overwriteFile(
         rootDirectory.getRelative("MODULE.bazel").getPathString(),
         "module(name='aaa')",
-        "include('//java:MODULE.bazel.segment')",
+        "include('//java:java.MODULE.bazel')",
         "bazel_dep(name='foo', version='1.0')",
         "register_toolchains('//:whatever')",
-        "include('//python:MODULE.bazel.segment')");
+        "include('//python:python.MODULE.bazel')");
     scratch.overwriteFile(rootDirectory.getRelative("java/BUILD").getPathString());
     scratch.overwriteFile(
-        rootDirectory.getRelative("java/MODULE.bazel.segment").getPathString(),
+        rootDirectory.getRelative("java/java.MODULE.bazel").getPathString(),
         "bazel_dep(name='java-foo', version='1.0')");
     scratch.overwriteFile(rootDirectory.getRelative("python/BUILD").getPathString());
     scratch.overwriteFile(
-        rootDirectory.getRelative("python/MODULE.bazel.segment").getPathString(),
+        rootDirectory.getRelative("python/python.MODULE.bazel").getPathString(),
         "bazel_dep(name='py-foo', version='1.0', repo_name='python-foo')",
         "single_version_override(module_name='java-foo', version='2.0')",
-        "include('//python:toolchains/MODULE.bazel.segment')");
+        "include('//python:toolchains/toolchains.MODULE.bazel')");
     scratch.overwriteFile(
-        rootDirectory.getRelative("python/toolchains/MODULE.bazel.segment").getPathString(),
+        rootDirectory.getRelative("python/toolchains/toolchains.MODULE.bazel").getPathString(),
         "register_toolchains('//:python-whatever')");
     FakeRegistry registry = registryFactory.newFakeRegistry("/foo");
     ModuleFileFunction.REGISTRIES.set(differencer, ImmutableList.of(registry.getUrl()));
@@ -387,7 +387,7 @@ public class ModuleFileFunctionTest extends FoundationTestCase {
     scratch.overwriteFile(
         rootDirectory.getRelative("MODULE.bazel").getPathString(),
         "module(name='aaa')",
-        "include('@haha//java:MODULE.bazel.segment')");
+        "include('@haha//java:java.MODULE.bazel')");
     FakeRegistry registry = registryFactory.newFakeRegistry("/foo");
     ModuleFileFunction.REGISTRIES.set(differencer, ImmutableList.of(registry.getUrl()));
 
@@ -403,7 +403,7 @@ public class ModuleFileFunctionTest extends FoundationTestCase {
     scratch.overwriteFile(
         rootDirectory.getRelative("MODULE.bazel").getPathString(),
         "module(name='aaa')",
-        "include(':MODULE.bazel.segment')");
+        "include(':relative.MODULE.bazel')");
     FakeRegistry registry = registryFactory.newFakeRegistry("/foo");
     ModuleFileFunction.REGISTRIES.set(differencer, ImmutableList.of(registry.getUrl()));
 
@@ -414,12 +414,28 @@ public class ModuleFileFunctionTest extends FoundationTestCase {
     assertThat(result.getError().toString()).contains("starting with double slashes");
   }
 
+  @Test
+  public void testRootModule_include_bad_notEndingInModuleBazel() throws Exception {
+    scratch.overwriteFile(
+        rootDirectory.getRelative("MODULE.bazel").getPathString(),
+        "module(name='aaa')",
+        "include('//:MODULE.bazel.segment')");
+    FakeRegistry registry = registryFactory.newFakeRegistry("/foo");
+    ModuleFileFunction.REGISTRIES.set(differencer, ImmutableList.of(registry.getUrl()));
+
+    EvaluationResult<RootModuleFileValue> result =
+        evaluator.evaluate(
+            ImmutableList.of(ModuleFileValue.KEY_FOR_ROOT_MODULE), evaluationContext);
+    assertThat(result.hasError()).isTrue();
+    assertThat(result.getError().toString()).contains("have a name ending in '.MODULE.bazel'");
+  }
+
   @Test
   public void testRootModule_include_bad_badLabelSyntax() throws Exception {
     scratch.overwriteFile(
         rootDirectory.getRelative("MODULE.bazel").getPathString(),
         "module(name='aaa')",
-        "include('//haha/:::')");
+        "include('//haha/:::.MODULE.bazel')");
     FakeRegistry registry = registryFactory.newFakeRegistry("/foo");
     ModuleFileFunction.REGISTRIES.set(differencer, ImmutableList.of(registry.getUrl()));
 
@@ -435,10 +451,10 @@ public class ModuleFileFunctionTest extends FoundationTestCase {
   public void testRootModule_include_bad_moduleAfterInclude() throws Exception {
     scratch.overwriteFile(
         rootDirectory.getRelative("MODULE.bazel").getPathString(),
-        "include('//java:MODULE.bazel.segment')");
+        "include('//java:java.MODULE.bazel')");
     scratch.overwriteFile(rootDirectory.getRelative("java/BUILD").getPathString());
     scratch.overwriteFile(
-        rootDirectory.getRelative("java/MODULE.bazel.segment").getPathString(),
+        rootDirectory.getRelative("java/java.MODULE.bazel").getPathString(),
         "module(name='bet-you-didnt-expect-this-didya')",
         "bazel_dep(name='java-foo', version='1.0', repo_name='foo')");
     FakeRegistry registry = registryFactory.newFakeRegistry("/foo");
@@ -457,15 +473,15 @@ public class ModuleFileFunctionTest extends FoundationTestCase {
     scratch.overwriteFile(
         rootDirectory.getRelative("MODULE.bazel").getPathString(),
         "module(name='aaa')",
-        "include('//java:MODULE.bazel.segment')",
-        "include('//python:MODULE.bazel.segment')");
+        "include('//java:java.MODULE.bazel')",
+        "include('//python:python.MODULE.bazel')");
     scratch.overwriteFile(rootDirectory.getRelative("java/BUILD").getPathString());
     scratch.overwriteFile(
-        rootDirectory.getRelative("java/MODULE.bazel.segment").getPathString(),
+        rootDirectory.getRelative("java/java.MODULE.bazel").getPathString(),
         "bazel_dep(name='java-foo', version='1.0', repo_name='foo')");
     scratch.overwriteFile(rootDirectory.getRelative("python/BUILD").getPathString());
     scratch.overwriteFile(
-        rootDirectory.getRelative("python/MODULE.bazel.segment").getPathString(),
+        rootDirectory.getRelative("python/python.MODULE.bazel").getPathString(),
         "bazel_dep(name='python-foo', version='1.0', repo_name='foo')");
     FakeRegistry registry = registryFactory.newFakeRegistry("/foo");
     ModuleFileFunction.REGISTRIES.set(differencer, ImmutableList.of(registry.getUrl()));
@@ -484,10 +500,10 @@ public class ModuleFileFunctionTest extends FoundationTestCase {
         rootDirectory.getRelative("MODULE.bazel").getPathString(),
         "module(name='aaa')",
         "FOO_NAME = 'foo'",
-        "include('//java:MODULE.bazel.segment')");
+        "include('//java:java.MODULE.bazel')");
     scratch.overwriteFile(rootDirectory.getRelative("java/BUILD").getPathString());
     scratch.overwriteFile(
-        rootDirectory.getRelative("java/MODULE.bazel.segment").getPathString(),
+        rootDirectory.getRelative("java/java.MODULE.bazel").getPathString(),
         "bazel_dep(name=FOO_NAME, version='1.0')");
     FakeRegistry registry = registryFactory.newFakeRegistry("/foo");
     ModuleFileFunction.REGISTRIES.set(differencer, ImmutableList.of(registry.getUrl()));
diff --git a/src/test/py/bazel/bzlmod/bazel_module_test.py b/src/test/py/bazel/bzlmod/bazel_module_test.py
index ef71d338c97..41915a6bce3 100644
--- a/src/test/py/bazel/bzlmod/bazel_module_test.py
+++ b/src/test/py/bazel/bzlmod/bazel_module_test.py
@@ -906,12 +906,12 @@ class BazelModuleTest(test_base.TestBase):
         [
             'module(name="foo")',
             'bazel_dep(name="bbb", version="1.0")',
-            'include("//java:MODULE.bazel.segment")',
+            'include("//java:java.MODULE.bazel")',
         ],
     )
     self.ScratchFile('java/BUILD')
     self.ScratchFile(
-        'java/MODULE.bazel.segment',
+        'java/java.MODULE.bazel',
         [
             'bazel_dep(name="aaa", version="1.0", repo_name="lol")',
         ],
-- 
GitLab