Intermittent patch to TFRT to submit a TF/TFRT cross-cutting change.
This patch will be applied only until TF's TFRT commit is automatically bumped.

---

diff --git a/backends/gpu/include/tfrt/gpu/gpu_types.h b/backends/gpu/include/tfrt/gpu/gpu_types.h
index 3d311c3..a216716 100644
--- a/backends/gpu/include/tfrt/gpu/gpu_types.h
+++ b/backends/gpu/include/tfrt/gpu/gpu_types.h
@@ -295,11 +295,7 @@
       wrapper::CurrentContext current, wrapper::Stream stream,
       wrapper::CclComm comm)>;
 
-  explicit GpuCclHandle(AsyncValueRef<GpuContext> context,
-                        wrapper::OwningCclComm comm, int num_ranks);
-  // TODO(hanbinyoon): Remove after transitioning to the above constructor.
-  explicit GpuCclHandle(AsyncValueRef<GpuContext> context,
-                        wrapper::OwningCclComm comm);
+  GpuCclHandle(AsyncValueRef<GpuContext> context, wrapper::OwningCclComm comm);
   ~GpuCclHandle();
 
   GpuCclHandle(GpuCclHandle&&) = default;
@@ -311,8 +307,6 @@
   llvm::Error ExecuteCallbacks(wrapper::CurrentContext current,
                                wrapper::Stream stream);
 
-  int num_ranks() const { return num_ranks_; }
-
   const wrapper::OwningCclComm& operator->() const { return comm_; }
   wrapper::CclComm get() const { return comm_.get(); }
   wrapper::CclComm release();
@@ -322,7 +316,6 @@
  private:
   AsyncValueRef<GpuContext> context_;
   wrapper::OwningCclComm comm_;
-  int num_ranks_;
   std::vector<Callback> callbacks_;
 };
 
diff --git a/backends/gpu/lib/gpu_types.cc b/backends/gpu/lib/gpu_types.cc
index 38529bc..01e3dba 100644
--- a/backends/gpu/lib/gpu_types.cc
+++ b/backends/gpu/lib/gpu_types.cc
@@ -214,15 +214,8 @@
 GpuBlasHandle::~GpuBlasHandle() = default;
 
 GpuCclHandle::GpuCclHandle(AsyncValueRef<GpuContext> context,
-                           wrapper::OwningCclComm comm, int num_ranks)
-    : context_(std::move(context)),
-      comm_(std::move(comm)),
-      num_ranks_(num_ranks) {}
-
-// TODO(hanbinyoon): Remove after transitioning to the above constructor.
-GpuCclHandle::GpuCclHandle(AsyncValueRef<GpuContext> context,
                            wrapper::OwningCclComm comm)
-    : context_(std::move(context)), comm_(std::move(comm)), num_ranks_(0) {}
+    : context_(std::move(context)), comm_(std::move(comm)) {}
 
 GpuCclHandle::~GpuCclHandle() = default;
 
diff --git a/backends/gpu/lib/kernels/ccl_kernels.cc b/backends/gpu/lib/kernels/ccl_kernels.cc
index 52ce820..9cfc1de 100644
--- a/backends/gpu/lib/kernels/ccl_kernels.cc
+++ b/backends/gpu/lib/kernels/ccl_kernels.cc
@@ -107,8 +107,6 @@
   auto width = ToWidthInBytes(type);
   if (!width) return width.takeError();
   assert(*width != 0);
-  if (input->size() != output->size() * handle->num_ranks())
-    return MakeStringError("Input size must be output size times ranks.");
 
   handle->AddCallback([input = input.ValueRef(), output = output.ValueRef(),
                        recvcount = output->size() / *width, type,
@@ -116,6 +114,10 @@
                           wrapper::CurrentContext current,
                           wrapper::Stream stream,
                           wrapper::CclComm comm) -> llvm::Error {
+    auto count = wrapper::CclCommCount(comm);
+    if (!count) return count.takeError();
+    if (input->size() != output->size() * *count)
+      return MakeStringError("Input size must be output size times ranks.");
     return wrapper::CclReduceScatter(current, input->pointer(),
                                      output->pointer(), recvcount, type, op,
                                      comm, stream);
