/* Copyright 2017 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/cpu/dot_op_emitter.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/Casting.h"
#include "xla/backends/cpu/codegen/target_machine_features.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/layout.h"
#include "xla/layout_util.h"
#include "xla/primitive_util.h"
#include "xla/service/cpu/backend_config.pb.h"
#include "xla/service/cpu/cpu_options.h"
#include "xla/service/cpu/cpu_runtime.h"
#include "xla/service/cpu/tiled_dot_emitter.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/llvm_ir/ir_array.h"
#include "xla/service/llvm_ir/kernel_support_library.h"
#include "xla/service/llvm_ir/llvm_loop.h"
#include "xla/service/llvm_ir/llvm_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/tsl/lib/math/math_util.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"

namespace xla {

using llvm_ir::SetToFirstInsertPoint;

namespace cpu {

namespace {
// Reports whether the module's debug options ask for the multi-threaded Eigen
// runtime routines.
bool ShouldUseMultiThreadedEigen(const HloModuleConfig& config) {
  const auto& debug_options = config.debug_options();
  return debug_options.xla_cpu_multi_thread_eigen();
}

// Tells whether `shape` has exactly two dimensions.
bool IsRank2(const Shape& shape) {
  const auto rank = shape.dimensions().size();
  return rank == 2;
}

// A layout is considered "simple" when it carries no tiles.
bool IsSimpleLayout(const Layout& layout) {
  const bool is_tiled = !layout.tiles().empty();
  return !is_tiled;
}

// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
bool AreGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
                   const Shape& output_shape,
                   const TargetMachineFeatures& target_machine_features) {
  CHECK(!lhs_shape.has_layout() || IsSimpleLayout(lhs_shape.layout()))
      << lhs_shape.ToString();
  CHECK(!rhs_shape.has_layout() || IsSimpleLayout(rhs_shape.layout()))
      << rhs_shape.ToString();
  CHECK(!output_shape.has_layout() || IsSimpleLayout(output_shape.layout()))
      << output_shape.ToString();

  switch (output_shape.element_type()) {
    case F16:
    case F32:
    case F64:
    case C64:
    case C128:
    case S32:
      return IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape);
    default:
      return false;
  }
}

// Returns true when neither operand is a zero-element array and the operand
// and result shapes pass AreGemmShapes.
bool IsAlignedGemm(const DotInfo& dot_info,
                   const TargetMachineFeatures& target_machine_features) {
  const bool has_empty_operand =
      ShapeUtil::IsZeroElementArray(dot_info.lhs_shape) ||
      ShapeUtil::IsZeroElementArray(dot_info.rhs_shape);

  // The shape check is only evaluated for non-empty operands.
  return !has_empty_operand &&
         AreGemmShapes(dot_info.lhs_shape, dot_info.rhs_shape,
                       dot_info.result_shape, target_machine_features);
}

bool CanEmitTiledLlvmIrGemm(
    const HloModuleConfig& config, const DotInfo& dot_info,
    const TargetMachineFeatures& target_machine_features) {
  CHECK(IsAlignedGemm(dot_info, target_machine_features));

  if (ShouldUseMultiThreadedEigen(config)) {
    return false;
  }

  int m = dot_info.result_shape.dimensions(0);
  int k = dot_info.lhs_shape.dimensions(
      dot_info.dim_nums.lhs_contracting_dimensions(0));
  int n = dot_info.result_shape.dimensions(1);

  if (!options::ForceEnableExperimentalLlvmIrGemm(config)) {
    // TODO(sanjoy):  We should make these numbers micro-arch specific.
    bool small_gemm =
        k <= 128 && ((m <= 32 && n <= 128) || (m <= 128 && n <= 32));
    if (!small_gemm) {
      return false;
    }
  }

  bool lhs_canonical = dot_info.dim_nums.lhs_contracting_dimensions(0) == 1;
  bool rhs_canonical = dot_info.dim_nums.rhs_contracting_dimensions(0) == 0;

  if (!(lhs_canonical && rhs_canonical)) {
    return false;
  }

  if (dot_info.result_shape.element_type() == F16 ||
      dot_info.result_shape.element_type() == C64 ||
      dot_info.result_shape.element_type() == C128) {
    // TODO(sanjoy): This is probably easy to fix, but I want to keep the CL
    // adding this comment NFC.
    return false;
  }

  return true;
}

// Picks the lowering strategy for a dot without batch dimensions.
//
// The choice is made in this order:
//   1. kTiledLlvmIrGemv when the result is vector-like and the element type is
//      floating point or integral.
//   2. kNaiveLlvmIr when both operands have a tiny dimension and the element
//      type is floating point or integral.
//   3. For aligned GEMMs: kTiledLlvmIrGemm if the tiled emitter applies,
//      otherwise kEigen when runtime calls are allowed.
//   4. kNaiveLlvmIr as the fallback.
DotImplementationStrategy GetNonBatchDotImplementationStrategy(
    const HloModuleConfig& config, const DotInfo& dot_info,
    const TargetMachineFeatures& target_machine_features,
    bool allow_runtime_calls) {
  const PrimitiveType element_type = dot_info.result_shape.element_type();

  // Batched dot either handled by a runtime call or expanded into a sequence
  // of non-batch dot operations.
  DCHECK(dot_info.dim_nums.lhs_batch_dimensions_size() == 0 &&
         dot_info.dim_nums.rhs_batch_dimensions_size() == 0)
      << "Dot operations must be non-batch";

  // Rank <= 1, or rank 2 with at least one dimension equal to 1.
  const auto is_vector_like = [](const Shape& shape) {
    const auto rank = shape.dimensions().size();
    if (rank <= 1) {
      return true;
    }
    return rank == 2 &&
           (shape.dimensions(0) == 1 || shape.dimensions(1) == 1);
  };

  // Rank <= 1, or rank 2 with at least one dimension of size at most 3.
  const auto has_tiny_dimension = [](const Shape& shape) {
    const auto rank = shape.dimensions().size();
    if (rank <= 1) {
      return true;
    }
    return rank == 2 &&
           (shape.dimensions(0) <= 3 || shape.dimensions(1) <= 3);
  };

  const bool is_float_or_integral =
      primitive_util::IsFloatingPointType(element_type) ||
      primitive_util::IsIntegralType(element_type);

  // Matrix-vector products (and transpose-dot fusions of the same) go to the
  // tiled LLVM IR GEMV emitter.
  if (is_vector_like(dot_info.result_shape) && is_float_or_integral) {
    return DotImplementationStrategy::kTiledLlvmIrGemv;
  }

  // Very thin operands are handled by the naive nested loop.
  if (has_tiny_dimension(dot_info.lhs_shape) &&
      has_tiny_dimension(dot_info.rhs_shape) && is_float_or_integral) {
    return DotImplementationStrategy::kNaiveLlvmIr;
  }

  if (!IsAlignedGemm(dot_info, target_machine_features)) {
    return DotImplementationStrategy::kNaiveLlvmIr;
  }

  if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) {
    return DotImplementationStrategy::kTiledLlvmIrGemm;
  }

  return allow_runtime_calls ? DotImplementationStrategy::kEigen
                             : DotImplementationStrategy::kNaiveLlvmIr;
}

// Emits the LLVM IR that implements one dot operation, choosing between the
// scalar, naive, tiled and runtime-call lowerings.
class DotOpEmitter {
 public:
  explicit DotOpEmitter(
      DotInfo dot_info, std::string dot_hlo_name,
      const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
      const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array,
      llvm::Value* work_group_id, llvm::Value* executable_run_options_value,
      llvm::IRBuilderBase* b, const HloModuleConfig& hlo_module_config,
      const TargetMachineFeatures& target_machine_features,
      bool allow_runtime_calls, bool allow_parallelism);

  // Generates the IR for the dot.  On success the returned value is the number
  // of workgroups along the X dimension over which the dot can be
  // parallelized.
  absl::StatusOr<uint64_t> Emit();

  // Generates the IR for the batch dot.
  absl::Status EmitBatch();

 private:
  // Scalar case: multiplies the LHS and RHS scalars and writes the product to
  // the target.
  absl::Status EmitScalarDot();

  // Generates a call into the CPU runtime that performs the matrix multiply.
  absl::Status EmitCallToRuntime();

  // Generates a call into the CPU runtime that performs the batch matrix
  // multiply.
  absl::Status EmitCallToBatchRuntime();

  // Describes the extents and operand layouts of a matrix-matrix multiply.
  struct MatMultDims {
    // Row count of the LHS.
    int64_t m;

    // Column count of the LHS; this must also equal the row count of the RHS.
    int64_t k;

    // Column count of the RHS.
    int64_t n;

    // Whether the LHS matrix is column major.
    bool lhs_column_major;

    // Whether the LHS contraction dimension is 1.
    bool lhs_canonical;

    // Whether the RHS matrix is column major.
    bool rhs_column_major;

    // Whether the RHS contraction dimension is 0.
    bool rhs_canonical;
  };

  // Builds the MatMultDims for the dot this emitter represents.
  // Precondition: the dot is of rank 2 (and thus so are its operands).
  MatMultDims GetMatMultDims() const;

  // Builds the MatMultDims for the dot this emitter represents.
  // Precondition: the dot is of rank 3 (and thus so are its operands).
  MatMultDims GetBatchMatMultDims() const;

  // Tiled Matrix*Vector lowering.  The return value is the number of
  // workgroups along the X dimension over which the dot can be parallelized.
  int64_t EmitTiledLlvmIrGemv();

  // Tiled Matrix*Matrix lowering.
  void EmitTiledLlvmIrGemm();

  // Naive nested-loop lowering that produces the result one element at a time.
  void EmitNaiveLlvmIrGemm();

  // Number of vector registers that make up a "tile" in the tiled LLVM IR
  // GEMV.  A value from the module config wins; otherwise 8 is used.
  int64_t GetGemvTilingFactor() const {
    const auto configured_factor =
        options::LlvmIrGemvTilingFactor(hlo_module_config_);
    if (configured_factor.has_value()) {
      return *configured_factor;
    }
    return int64_t{8};
  }

  // Tile size for the tiled LLVM IR GEMM, unpacked by the caller as
  // (tile_size_m, tile_size_k, tile_size_n_in_vector_width).  A value from the
  // module config wins over the built-in default.
  std::tuple<int64_t, int64_t, int64_t> GetGemmTileSize() const {
    using TileSize = std::tuple<int64_t, int64_t, int64_t>;

    const auto configured_tile =
        options::LlvmIrGemmTileSize(hlo_module_config_);
    if (configured_tile.has_value()) {
      return *configured_tile;
    }

    // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz
    //
    // TODO(b/80093688): Tune for other architectures and centralize this
    // information in one place.
    return TileSize(11, 9, 1);
  }

  DotInfo dot_info_;
  std::string dot_hlo_name_;
  const llvm_ir::IrArray& target_array_;
  const llvm_ir::IrArray& lhs_array_;
  const llvm_ir::IrArray& rhs_array_;
  const llvm_ir::IrArray* addend_array_;
  llvm::Value* work_group_id_;
  llvm::Value* executable_run_options_value_;
  llvm::IRBuilderBase* b_;
  const HloModuleConfig& hlo_module_config_;
  const TargetMachineFeatures& target_machine_features_;
  bool allow_runtime_calls_;
  bool allow_parallelism_;
};
}  // namespace

// Stores each argument in the corresponding member; no IR is emitted here.
//
// `dot_info` and `dot_hlo_name` are taken by value and moved into the emitter.
// The target/lhs/rhs arrays, the module config and the target machine features
// are kept as references, and `addend_array`, `work_group_id`,
// `executable_run_options_value` and `b` as raw pointers, so all of them must
// stay alive for as long as the emitter is used.
DotOpEmitter::DotOpEmitter(
    DotInfo dot_info, std::string dot_hlo_name,
    const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
    const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array,
    llvm::Value* work_group_id, llvm::Value* executable_run_options_value,
    llvm::IRBuilderBase* b, const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features,
    bool allow_runtime_calls, bool allow_parallelism)
    : dot_info_(std::move(dot_info)),
      dot_hlo_name_(std::move(dot_hlo_name)),
      target_array_(target_array),
      lhs_array_(lhs_array),
      rhs_array_(rhs_array),
      addend_array_(addend_array),
      work_group_id_(work_group_id),
      executable_run_options_value_(executable_run_options_value),
      b_(b),
      hlo_module_config_(hlo_module_config),
      target_machine_features_(target_machine_features),
      allow_runtime_calls_(allow_runtime_calls),
      allow_parallelism_(allow_parallelism) {}

void DotOpEmitter::EmitTiledLlvmIrGemm() {
  PrimitiveType primitive_type = dot_info_.result_shape.element_type();
  MatMultDims mat_mult_dims = GetMatMultDims();

  llvm::Value* lhs = lhs_array_.GetBasePointer();
  llvm::Value* rhs = rhs_array_.GetBasePointer();
  llvm::Value* target = target_array_.GetBasePointer();
  int64_t m = mat_mult_dims.m;
  int64_t k = mat_mult_dims.k;
  int64_t n = mat_mult_dims.n;

  if (mat_mult_dims.lhs_column_major) {
    std::swap(lhs, rhs);
    std::swap(m, n);
  }

  int64_t size_bytes =
      m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
  b_->CreateMemSet(target, b_->getInt8(0), /*Size=*/size_bytes,
                   /*Align=*/llvm::MaybeAlign(1));

  int64_t max_target_vector_width =
      target_machine_features_.vector_register_num_elements(
          *b_->GetInsertBlock()->getParent(), primitive_type);

  int64_t tile_size_m, tile_size_k, tile_size_n_in_vector_width;
  std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) =
      GetGemmTileSize();

  EmitSmallGemm(
      /*scalar_type=*/primitive_type,
      /*m=*/m, /*k=*/k, /*n=*/n,
      /*max_vectorization_width=*/max_target_vector_width,
      /*max_vector_count=*/tile_size_n_in_vector_width,
      /*min_vectorization_width=*/std::min<int64_t>(4, max_target_vector_width),
      /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k, /*lhs=*/lhs,
      /*rhs=*/rhs, /*result=*/target, b_, hlo_module_config_);
}

int64_t DotOpEmitter::EmitTiledLlvmIrGemv() {
  PrimitiveType primitive_type = dot_info_.result_shape.element_type();

  CHECK(primitive_util::IsFloatingPointType(primitive_type) ||
        primitive_util::IsIntegralType(primitive_type));

  MatMultDims mat_mult_dims = GetMatMultDims();
  bool is_column_major_matrix_vector_gemv = false;
  bool is_row_major_matrix_vector_gemv = false;

  int64_t m, k;
  bool swap_operands;

  if (mat_mult_dims.m == 1) {
    // Our emitters can only do Matrix*Vector (abbreviated as M*V) but when M=1
    // we actually want V*M.  We implement V*M as follows (Tr(X) = Transpose of
    // X):
    //
    //   V*M = Tr(Tr(V*M))  // Tr(Tr(X)) == X
    //       = Tr(Tr(M) * Tr(V))  // Tr(A * B) == Tr(B) * Tr(A)
    //
    // Since transposing a vector is physically a no-op, this is really
    // equivalent to `Tr(M) * V`.  We further implement Tr(M) by pretending that
    // M is row major if it is actually column major and vice-versa.

    bool rhs_effectively_column_major = mat_mult_dims.rhs_canonical
                                            ? mat_mult_dims.rhs_column_major
                                            : !mat_mult_dims.rhs_column_major;

    if (rhs_effectively_column_major) {
      k = mat_mult_dims.k;
      m = mat_mult_dims.n;

      // We set is_row_major_matrix_vector_gemv and not
      // is_column_major_matrix_vector_gemv to implement the Transpose trick
      // mentioned above.
      is_row_major_matrix_vector_gemv = true;
      swap_operands = true;
    } else {
      k = mat_mult_dims.k;
      m = mat_mult_dims.n;

      // We set is_column_major_matrix_vector_gemv and not
      // is_row_major_matrix_vector_gemv to implement the Transpose trick
      // mentioned above.
      is_column_major_matrix_vector_gemv = true;
      swap_operands = true;
    }
  }

  if (mat_mult_dims.n == 1) {
    bool lhs_effectively_column_major = mat_mult_dims.lhs_canonical
                                            ? mat_mult_dims.lhs_column_major
                                            : !mat_mult_dims.lhs_column_major;

    if (lhs_effectively_column_major) {
      m = mat_mult_dims.m;
      k = mat_mult_dims.k;
      is_column_major_matrix_vector_gemv = true;
      swap_operands = false;
    } else {
      m = mat_mult_dims.m;
      k = mat_mult_dims.k;
      is_row_major_matrix_vector_gemv = true;
      swap_operands = false;
    }
  }

  CHECK(is_column_major_matrix_vector_gemv || is_row_major_matrix_vector_gemv);

  int64_t tiling_factor = GetGemvTilingFactor();
  CHECK_GT(tiling_factor, 0);

  llvm::Value* result_op = target_array_.GetBasePointer();
  llvm::Value* lhs_op =
      swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer();
  llvm::Value* rhs_op =
      swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();

  const int target_vector_register_element_size =
      target_machine_features_.vector_register_num_elements(
          *b_->GetInsertBlock()->getParent(), primitive_type);

  // We may not always know the vector register size for the target we're
  // compiling against, in which case target_vector_register_element_size is 0.
  // In these cases we choose a default LLVM IR register size.
  const int kUnknownTargetVectorRegisterSize = 4;
  const int vector_register_element_size =
      target_vector_register_element_size == 0
          ? kUnknownTargetVectorRegisterSize
          : target_vector_register_element_size;

// We parallelize the GEMV computation to have at least this many FMA
// instructions per task. In debug builds we prefer smaller tasks to test that
// we correctly parallelize the loop.
#ifdef NDEBUG
  static constexpr int64_t kFmaPerTask = 1 << 19;  // 0.5M FMA/task
#else
  static constexpr int64_t kFmaPerTask = 1 << 12;  // 4096 FMA/task
#endif

  // GEMV has very little data reuse, and we hit memory bandwidth bound
  // before we hit compute bound. So we limit the number of tasks to avoid
  // excessive task scheduling overheads.
  static constexpr int64_t kMaxTasks = 8;

  // Compute into how many tasks along the parallel dimension 'm' we can divide
  // the work (we do accumulation along the k dimension).
  int64_t m_per_task = tsl::MathUtil::CeilOfRatio(kFmaPerTask, k);
  int64_t num_tasks =
      std::min(tsl::MathUtil::CeilOfRatio(m, m_per_task), kMaxTasks);

  // If parallelism is not allowed, we always assume that we execute one task.
  if (!allow_parallelism_) {
    num_tasks = 1;
  }

  if (is_column_major_matrix_vector_gemv) {
    VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
            << " and k = " << k << "; num_tasks = " << num_tasks;

    EmitColumnMajorGemv(
        /*scalar_type=*/primitive_type, num_tasks, work_group_id_,
        /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor,
        /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
        /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
        /*result=*/result_op, b_, hlo_module_config_);
    return num_tasks;
  } else {
    VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
            << " and k = " << k << "; num_tasks = " << num_tasks;

    EmitRowMajorGemv(
        /*scalar_type=*/primitive_type, num_tasks, work_group_id_,
        /*tile_rows=*/tiling_factor,
        /*tile_cols=*/vector_register_element_size,
        /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op,
        /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr,
        /*result=*/result_op, b_, hlo_module_config_);
    return num_tasks;
  }
}

// Emits the IR for a non-batch dot and returns the number of workgroups along
// the X dimension that can be used to parallelize it.
//
// Scalar operands are handled by EmitScalarDot.  Everything else is dispatched
// on the strategy picked by GetNonBatchDotImplementationStrategy:
//   * kNaiveLlvmIr     -> EmitNaiveLlvmIrGemm (one workgroup)
//   * kTiledLlvmIrGemv -> EmitTiledLlvmIrGemv (workgroup count it reports)
//   * kTiledLlvmIrGemm -> EmitTiledLlvmIrGemm (one workgroup)
//   * kEigen           -> EmitCallToRuntime   (one workgroup)
//
// Errors from the scalar and runtime-call paths are propagated; a TF_RET_CHECK
// failure is returned if exactly one of the operands is scalar.
absl::StatusOr<uint64_t> DotOpEmitter::Emit() {
  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();

  if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) {
    // If the operands are scalar, don't emit any loops.
    TF_RET_CHECK(ShapeUtil::IsScalar(lhs_shape) &&
                 ShapeUtil::IsScalar(rhs_shape));
    TF_RETURN_IF_ERROR(EmitScalarDot());
    return 1;
  }

  switch (GetNonBatchDotImplementationStrategy(hlo_module_config_, dot_info_,
                                               target_machine_features_,
                                               allow_runtime_calls_)) {
    case DotImplementationStrategy::kNaiveLlvmIr:
      EmitNaiveLlvmIrGemm();
      return 1;

    case DotImplementationStrategy::kTiledLlvmIrGemv:
      return EmitTiledLlvmIrGemv();

    case DotImplementationStrategy::kTiledLlvmIrGemm:
      EmitTiledLlvmIrGemm();
      return 1;

    case DotImplementationStrategy::kEigen:
      TF_RETURN_IF_ERROR(EmitCallToRuntime());
      return 1;
  }

  // Every case above returns.  This is only reached if the strategy holds a
  // value that is not one of the handled enumerators; report it as an error
  // instead of flowing off the end of a value-returning function.  There is
  // intentionally no `default:` label so that the compiler can still warn
  // about enumerators missing from the switch.
  return Internal("Unhandled dot implementation strategy.");
}

// Emits the batch dot.  No loop nest is built here: the work is delegated
// entirely to EmitCallToBatchRuntime(), and its status is returned unchanged.
absl::Status DotOpEmitter::EmitBatch() {
  return EmitCallToBatchRuntime();
}

// Emits the dot as a plain loop nest: one loop per non-contracting dimension
// of the LHS and of the RHS, plus an innermost reduction loop that accumulates
// the sum of products into a stack slot before the value is stored to the
// target array.  An addend is not supported on this path (CHECKed below).  On
// return the builder's insert point is the exit block of the outermost loop.
void DotOpEmitter::EmitNaiveLlvmIrGemm() {
  CHECK_EQ(addend_array_, nullptr);

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();
  const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;

  // The reduction dimension of each operand is the first contracting dimension
  // listed for it in the dot dimension numbers.
  int64_t lhs_reduction_dimension = dim_nums.lhs_contracting_dimensions(0);
  int64_t rhs_reduction_dimension = dim_nums.rhs_contracting_dimensions(0);

  // Verify the reduction dimension in the two operands are the same size.
  CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension),
           rhs_shape.dimensions(rhs_reduction_dimension));

  // Whether each operand's reduction dimension is also its minor-most layout
  // dimension; used below to decide the unroll mode of the reduction loop.
  bool lhs_reduction_along_minor_dimension =
      lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0);
  bool rhs_reduction_along_minor_dimension =
      rhs_reduction_dimension == LayoutUtil::Minor(rhs_shape.layout(), 0);

  // Create loop nests which loop through the LHS operand dimensions and the RHS
  // operand dimensions. The reduction dimension of the LHS and RHS are handled
  // in a separate innermost loop which performs the sum of products.
  llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(dot_hlo_name_), b_);
  std::vector<llvm::Value*> lhs_multi_index =
      loop_nest.EmitOperandArrayLoopNest(
          lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs");
  std::vector<llvm::Value*> rhs_multi_index =
      loop_nest.EmitOperandArrayLoopNest(
          rhs_array_, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");

  // Create the loop which does the sum of products reduction.
  //
  // The prevent_unrolling bit is working around a deficiency in LLVM's loop
  // vectorization pipeline, wherein in some cases unrolling a loop can prevent
  // effective vectorization.  Since we know that the IR we generate when
  // reducing across the minor dimension in both LHS and RHS is vectorized well
  // by the loop vectorizer, we block unrolling in that case to stop loop unroll
  // from messing up the vectorization.
  std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
      0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction",
      /*unroll_mode=*/
      (lhs_reduction_along_minor_dimension &&
       rhs_reduction_along_minor_dimension)
          ? xla::llvm_ir::UnrollMode::kNoUnroll
          : xla::llvm_ir::UnrollMode::kDefaultUnroll);

  // Fill the reduction-dimension slot of each operand's multi-index with the
  // induction variable of the reduction loop.
  lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
  llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_shape,
                                    b_->getInt64Ty());
  rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();
  llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_shape,
                                    b_->getInt64Ty());

  // For computing the sum of products we alloca a single location to store the
  // dot product result as we accumulate it within the reduction loop. After the
  // reduction loop we load the result and store into the output array.

  // Function entry basic block.
  // - Emit alloca for accumulator
  llvm::Function* func = reduction_loop->GetPreheaderBasicBlock()->getParent();
  SetToFirstInsertPoint(&func->getEntryBlock(), b_);
  llvm::Type* accum_type = target_array_.GetElementLlvmType();
  llvm::Value* accum_address =
      b_->CreateAlloca(accum_type, /*ArraySize=*/nullptr, "accum_address");

  // Preheader basic block of reduction loop:
  // - Initialize accumulator to zero.
  llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock();
  b_->SetInsertPoint(preheader_bb->getTerminator());

  b_->CreateStore(llvm::Constant::getNullValue(accum_type), accum_address);

  // Body basic block of reduction loop:
  // - Load elements from lhs and rhs array.
  // - Multiply lhs-element and rhs-element.
  // - Load accumulator and add to product.
  // - Store sum back into accumulator.
  SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), b_);

  llvm::Value* lhs_element = lhs_array_.EmitReadArrayElement(lhs_index, b_);
  llvm::Value* rhs_element = rhs_array_.EmitReadArrayElement(rhs_index, b_);

  llvm::Value* accum = b_->CreateLoad(accum_type, accum_address);
  llvm::Value* updated_accum;
  // The multiply/accumulate instructions depend on the LHS element type:
  // complex values are handled component-wise, integers use integer mul/add,
  // PRED uses and/or, and everything else uses floating point mul/add.
  if (ShapeUtil::ElementIsComplex(lhs_shape)) {
    auto real = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {0}); };
    auto imag = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {1}); };
    llvm::Value* product_real =
        b_->CreateFSub(b_->CreateFMul(real(lhs_element), real(rhs_element)),
                       b_->CreateFMul(imag(lhs_element), imag(rhs_element)));
    llvm::Value* product_imag =
        b_->CreateFAdd(b_->CreateFMul(real(lhs_element), imag(rhs_element)),
                       b_->CreateFMul(imag(lhs_element), real(rhs_element)));
    updated_accum = b_->CreateInsertValue(
        accum, b_->CreateFAdd(real(accum), product_real), {0});
    updated_accum = b_->CreateInsertValue(
        updated_accum, b_->CreateFAdd(imag(accum), product_imag), {1});
  } else if (ShapeUtil::ElementIsIntegral(lhs_shape)) {
    llvm::Value* product = b_->CreateMul(lhs_element, rhs_element);
    updated_accum = b_->CreateAdd(accum, product);
  } else if (lhs_shape.element_type() == PRED) {
    llvm::Value* product = b_->CreateAnd(lhs_element, rhs_element);
    updated_accum = b_->CreateOr(accum, product);
  } else {
    llvm::Value* product = b_->CreateFMul(lhs_element, rhs_element);
    updated_accum = b_->CreateFAdd(accum, product);
  }
  b_->CreateStore(updated_accum, accum_address);

  // Exit basic block of reduction loop.
  // - Load accumulator value (the result).
  // - Store into output array.
  SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), b_);

  llvm::Value* result = b_->CreateLoad(accum_type, accum_address);

  // Create index into target address. The target index is the LHS index with
  // its reduction dimension removed, followed by the RHS index with its
  // reduction dimension removed.
  std::vector<llvm::Value*> target_multi_index;
  for (int dimension = 0; dimension < lhs_index.size(); ++dimension) {
    if (dimension != lhs_reduction_dimension) {
      target_multi_index.push_back(lhs_index[dimension]);
    }
  }
  for (int dimension = 0; dimension < rhs_index.size(); ++dimension) {
    if (dimension != rhs_reduction_dimension) {
      target_multi_index.push_back(rhs_index[dimension]);
    }
  }

  llvm_ir::IrArray::Index target_index(
      target_multi_index, target_array_.GetShape(), lhs_index.GetType());
  target_array_.EmitWriteArrayElement(target_index, result, b_);

  // Set the IR builder insert point to the exit basic block of the outer most
  // loop.
  b_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
}

// Emits a dot of two scalars, which is just a scalar multiply of the LHS and
// RHS elements written to the target.
//
// The multiply instruction is picked from the LHS element type, using the
// same product instructions as EmitNaiveLlvmIrGemm:
//   * complex  -> component-wise floating point complex product
//   * integral -> integer mul
//   * PRED     -> logical and
//   * other    -> floating point mul
//
// Always returns OkStatus.
absl::Status DotOpEmitter::EmitScalarDot() {
  const Shape& lhs_shape = lhs_array_.GetShape();

  llvm::Value* result;
  // Use the same index_type for all tensor accesses in the same kernel.
  llvm::Type* index_type = b_->getInt64Ty();
  llvm_ir::IrArray::Index element_index(index_type);
  llvm::Value* lhs_value =
      lhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
  llvm::Value* rhs_value =
      rhs_array_.EmitReadArrayElement(/*index=*/element_index, b_);
  if (ShapeUtil::ElementIsComplex(lhs_shape)) {
    auto get_real = [&](llvm::Value* x) {
      return b_->CreateExtractValue(x, {0});
    };

    auto get_imag = [&](llvm::Value* x) {
      return b_->CreateExtractValue(x, {1});
    };

    // (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
    llvm::Value* real = b_->CreateFSub(
        b_->CreateFMul(get_real(lhs_value), get_real(rhs_value)),
        b_->CreateFMul(get_imag(lhs_value), get_imag(rhs_value)));
    llvm::Value* imag = b_->CreateFAdd(
        b_->CreateFMul(get_real(lhs_value), get_imag(rhs_value)),
        b_->CreateFMul(get_imag(lhs_value), get_real(rhs_value)));
    result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType());
    result = b_->CreateInsertValue(result, real, {0});
    result = b_->CreateInsertValue(result, imag, {1});
  } else if (ShapeUtil::ElementIsIntegral(lhs_shape)) {
    // A floating point multiply is not valid on integer operands.
    result = b_->CreateMul(lhs_value, rhs_value);
  } else if (lhs_shape.element_type() == PRED) {
    // The product of two predicates is their conjunction.
    result = b_->CreateAnd(lhs_value, rhs_value);
  } else {
    result = b_->CreateFMul(lhs_value, rhs_value);
  }
  target_array_.EmitWriteArrayElement(/*index=*/element_index, result, b_);
  return absl::OkStatus();
}

// Emits a call to an external runtime matmul routine that computes the whole
// (non-batch) dot. The routine is selected from the result element type, from
// whether multi-threaded Eigen should be used, and from the xla_cpu_use_acl
// debug option.
//
// Returns an Internal error if runtime calls were disabled for this emitter and
// an Unimplemented error if the element type has no runtime routine.
absl::Status DotOpEmitter::EmitCallToRuntime() {
  if (!allow_runtime_calls_) {
    return Internal(
        "Trying to emit a call to runtime when it was explicitly disabled.");
  }

  // The signature of the Eigen runtime matmul function is:
  //
  //   (void)(void* run_options, float* out, float* lhs, float* rhs,
  //          int64_t m, int64_t n, int64_t k, int32_t transpose_lhs,
  //          int32_t transpose_rhs);
  // The two transpose_... parameters are actually booleans, but we use int32_t
  // to avoid target-dependent calling convention details.

  bool multi_threaded = ShouldUseMultiThreadedEigen(hlo_module_config_);
  bool use_acl = hlo_module_config_.debug_options().xla_cpu_use_acl();
  PrimitiveType type = target_array_.GetShape().element_type();
  llvm::Function* function = b_->GetInsertBlock()->getParent();
  llvm::Module* module = function->getParent();

  // Pick the runtime symbol for the element type. All data arguments are
  // passed as opaque pointers, so no per-type LLVM element type is required.
  const char* fn_name;
  switch (type) {
    case F16:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulF16SymbolName
                    : runtime::kEigenSingleThreadedMatMulF16SymbolName;
      break;
    case F32:
      // ACL is only selected for the multi-threaded F32 case.
      fn_name = multi_threaded
                    ? (use_acl ? runtime::kACLMatMulF32SymbolName
                               : runtime::kEigenMatMulF32SymbolName)
                    : runtime::kEigenSingleThreadedMatMulF32SymbolName;
      break;
    case F64:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulF64SymbolName
                    : runtime::kEigenSingleThreadedMatMulF64SymbolName;
      break;
    case C64:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulC64SymbolName
                    : runtime::kEigenSingleThreadedMatMulC64SymbolName;
      break;
    case C128:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulC128SymbolName
                    : runtime::kEigenSingleThreadedMatMulC128SymbolName;
      break;
    case S32:
      fn_name = multi_threaded
                    ? runtime::kEigenMatMulS32SymbolName
                    : runtime::kEigenSingleThreadedMatMulS32SymbolName;
      break;
    default:
      return Unimplemented("Invalid type %s for dot operation",
                           PrimitiveType_Name(type));
  }

  llvm::Type* ptr_type = b_->getPtrTy();
  llvm::Type* int64_type = b_->getInt64Ty();
  llvm::Type* int32_type = b_->getInt32Ty();
  llvm::FunctionType* matmul_type = llvm::FunctionType::get(
      b_->getVoidTy(),
      {ptr_type, ptr_type, ptr_type, ptr_type, int64_type, int64_type,
       int64_type, int32_type, int32_type},
      /*isVarArg=*/false);

  llvm::FunctionCallee matmul_func =
      module->getOrInsertFunction(fn_name, matmul_type);
  // Attributes can only be set when the callee is an llvm::Function.
  if (auto* fn = llvm::dyn_cast<llvm::Function>(matmul_func.getCallee())) {
    fn->setCallingConv(llvm::CallingConv::C);
    fn->setDoesNotThrow();
    fn->setOnlyAccessesArgMemory();
  }

  // The Eigen runtime function expects column-major layout. If the matrices are
  // row major, then use the following identity to compute the product:
  //
  //   (A x B)^T = B^T x A^T
  //
  // The connection between this identity and memory layout is that the
  // transpose operation can also be considered as an operation that changes the
  // memory layout of a matrix from row-major to column-major or vice versa.
  //
  // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.

  MatMultDims mat_mult_dims = GetMatMultDims();

  CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major);

  const llvm_ir::IrArray* lhs = &lhs_array_;
  const llvm_ir::IrArray* rhs = &rhs_array_;
  bool transpose_lhs = !mat_mult_dims.lhs_canonical;
  bool transpose_rhs = !mat_mult_dims.rhs_canonical;

  if (!mat_mult_dims.lhs_column_major) {
    std::swap(mat_mult_dims.m, mat_mult_dims.n);
    std::swap(lhs, rhs);
    std::swap(transpose_lhs, transpose_rhs);
  }

  b_->CreateCall(
      matmul_func,
      {executable_run_options_value_, target_array_.GetBasePointer(),
       lhs->GetBasePointer(), rhs->GetBasePointer(),
       b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n),
       b_->getInt64(mat_mult_dims.k),
       b_->getInt32(static_cast<uint32_t>(transpose_lhs)),
       b_->getInt32(static_cast<uint32_t>(transpose_rhs))});
  return absl::OkStatus();
}

// Emits a single call to an external runtime batch-matmul routine that
// computes every dot of the batch. Only F32 is supported; the routine is the
// ACL one when the xla_cpu_use_acl debug option is set and the Eigen one
// otherwise.
//
// Returns an Internal error if runtime calls were disabled for this emitter and
// an Unimplemented error for any element type other than F32.
absl::Status DotOpEmitter::EmitCallToBatchRuntime() {
  if (!allow_runtime_calls_) {
    return Internal(
        "Trying to emit a call to runtime when it was explicitly disabled.");
  }

  // The signature of the runtime batch matmul function is:
  //
  //   (void)(void* run_options, float* out, float* lhs, float* rhs,
  //          int64_t m, int64_t n, int64_t k, int64_t batch_size, int32_t
  //          transpose_lhs, int32_t transpose_rhs);
  // The two transpose_... parameters are actually booleans, but we use int32_t
  // to avoid target-dependent calling convention details.

  PrimitiveType type = target_array_.GetShape().element_type();
  bool use_acl = hlo_module_config_.debug_options().xla_cpu_use_acl();
  llvm::Function* function = b_->GetInsertBlock()->getParent();
  llvm::Module* module = function->getParent();

  // All data arguments are passed as opaque pointers, so only the symbol name
  // depends on the element type.
  const char* fn_name;
  switch (type) {
    case F32:
      fn_name = use_acl ? runtime::kACLBatchMatMulF32SymbolName
                        : runtime::kEigenBatchMatMulF32SymbolName;
      break;
    default:
      return Unimplemented("Invalid type %s for dot operation",
                           PrimitiveType_Name(type));
  }

  llvm::Type* ptr_type = b_->getPtrTy();
  llvm::Type* int64_type = b_->getInt64Ty();
  llvm::Type* int32_type = b_->getInt32Ty();
  llvm::FunctionType* matmul_type = llvm::FunctionType::get(
      b_->getVoidTy(),
      {ptr_type, ptr_type, ptr_type, ptr_type, int64_type, int64_type,
       int64_type, int64_type, int32_type, int32_type},
      /*isVarArg=*/false);

  llvm::FunctionCallee matmul_func =
      module->getOrInsertFunction(fn_name, matmul_type);
  // Attributes can only be set when the callee is an llvm::Function.
  if (auto* fn = llvm::dyn_cast<llvm::Function>(matmul_func.getCallee())) {
    fn->setCallingConv(llvm::CallingConv::C);
    fn->setDoesNotThrow();
    fn->setOnlyAccessesArgMemory();
  }

  // The runtime function (Eigen or ACL) expects column-major layout. If the
  // matrices are row major, then use the following identity to compute the
  // product:
  //
  //   (A x B)^T = B^T x A^T
  //
  // The connection between this identity and memory layout is that the
  // transpose operation can also be considered as an operation that changes the
  // memory layout of a matrix from row-major to column-major or vice versa.
  //
  // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.

  MatMultDims mat_mult_dims = GetBatchMatMultDims();
  CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major);

  const llvm_ir::IrArray* lhs = &lhs_array_;
  const llvm_ir::IrArray* rhs = &rhs_array_;
  bool transpose_lhs = !mat_mult_dims.lhs_canonical;
  bool transpose_rhs = !mat_mult_dims.rhs_canonical;
  // The batch size is read from the leading dimension of the original lhs
  // operand, independently of the lhs/rhs swap below.
  const Shape& lhs_shape = lhs_array_.GetShape();

  if (!mat_mult_dims.lhs_column_major) {
    std::swap(mat_mult_dims.m, mat_mult_dims.n);
    std::swap(lhs, rhs);
    std::swap(transpose_lhs, transpose_rhs);
  }

  VLOG(1) << "Batch dot emitted with runtime:" << fn_name;

  b_->CreateCall(
      matmul_func,
      {executable_run_options_value_, target_array_.GetBasePointer(),
       lhs->GetBasePointer(), rhs->GetBasePointer(),
       b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n),
       b_->getInt64(mat_mult_dims.k), b_->getInt64(lhs_shape.dimensions(0)),
       b_->getInt32(static_cast<uint32_t>(transpose_lhs)),
       b_->getInt32(static_cast<uint32_t>(transpose_rhs))});
  return absl::OkStatus();
}

// Describes the non-batch dot as an [m, k] x [k, n] matrix multiplication.
//
// In the returned struct:
//   - m / n are the non-contracted sizes of lhs / rhs (1 for operands of rank
//     <= 1) and k is the size of the contracted lhs dimension;
//   - *_column_major is true for operands of rank > 1 whose minor-most
//     dimension is dimension 0;
//   - lhs_canonical is true when lhs has rank <= 1 or contracts dimension 1,
//     rhs_canonical when rhs contracts dimension 0.
DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
  CHECK_LE(dot_info_.result_shape.dimensions().size(), 2);

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();
  const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;

  auto is_column_major = [](const Shape& shape) {
    return shape.dimensions().size() > 1 &&
           LayoutUtil::Minor(shape.layout(), 0) == 0;
  };

  // Non-contracting dots should never make it here. The first contracting
  // dimension of each operand is indexed below, so at least one must exist.
  CHECK_GT(dim_nums.lhs_contracting_dimensions_size(), 0);
  CHECK_GT(dim_nums.rhs_contracting_dimensions_size(), 0);

  return {
      /*m=*/lhs_shape.dimensions().size() <= 1
          ? 1LL
          : lhs_shape.dimensions(1LL - dim_nums.lhs_contracting_dimensions(0)),
      /*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)),
      /*n=*/rhs_shape.dimensions().size() <= 1
          ? 1LL
          : rhs_shape.dimensions(1LL - dim_nums.rhs_contracting_dimensions(0)),
      /*lhs_column_major=*/is_column_major(lhs_shape),
      /*lhs_canonical=*/lhs_shape.dimensions().size() <= 1 ||
          dim_nums.lhs_contracting_dimensions(0) == 1,
      /*rhs_column_major=*/is_column_major(rhs_shape),
      /*rhs_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 0};
}

// Batch counterpart of GetMatMultDims. The operand shapes are read from the
// arrays themselves, while the contracting dimension numbers come from
// dot_info_; the dimension lookups below are offset by one relative to
// GetMatMultDims (2 - c and 1 + c instead of 1 - c and c).
// NOTE(review): this presumably relies on dot_info_ describing the inner
// (batch-less) dot while the arrays keep a single leading batch dimension --
// confirm against the caller.
DotOpEmitter::MatMultDims DotOpEmitter::GetBatchMatMultDims() const {
  CHECK_LE(dot_info_.result_shape.dimensions().size(), 2);

  const Shape& lhs_shape = lhs_array_.GetShape();
  const Shape& rhs_shape = rhs_array_.GetShape();
  const DotDimensionNumbers& dim_nums = dot_info_.dim_nums;

  auto is_column_major = [](const Shape& shape) {
    return shape.dimensions().size() > 1 &&
           LayoutUtil::Minor(shape.layout(), 0) == 0;
  };

  // Non-contracting dots should never make it here. The first contracting
  // dimension of each operand is indexed below, so at least one must exist.
  CHECK_GT(dim_nums.lhs_contracting_dimensions_size(), 0);
  CHECK_GT(dim_nums.rhs_contracting_dimensions_size(), 0);

  return {
      /*m=*/lhs_shape.dimensions().size() <= 1
          ? 1LL
          : lhs_shape.dimensions(2LL - dim_nums.lhs_contracting_dimensions(0)),
      /*k=*/lhs_shape.dimensions(1LL + dim_nums.lhs_contracting_dimensions(0)),
      /*n=*/rhs_shape.dimensions().size() <= 1
          ? 1LL
          : rhs_shape.dimensions(2LL - dim_nums.rhs_contracting_dimensions(0)),
      /*lhs_column_major=*/is_column_major(lhs_shape),
      /*lhs_canonical=*/lhs_shape.dimensions().size() <= 1 ||
          dim_nums.lhs_contracting_dimensions(0) == 1,
      /*rhs_column_major=*/is_column_major(rhs_shape),
      /*rhs_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 0};
}

// For vector-matrix dot products, it is always profitable to make the Rhs
// column major.
std::optional<int64_t> ProfitableToMakeDotOperandColumnMajor(
    const HloInstruction& hlo) {
  if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions().size() <= 1) {
    if (hlo.operand(0)->shape().dimensions().size() != 1 ||
        hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) != 0) {
      return {};
    }

    // Don't bother if the other operand is tiny, switching to column major
    // wouldn't use tiling.
    constexpr int kColumnMajorThresholdInBytes = 32;
    int64_t lhs_size =
        ShapeUtil::ByteSizeOfPrimitiveType(hlo.shape().element_type()) *
        ShapeUtil::ElementsIn(hlo.operand(0)->shape());
    if (lhs_size < kColumnMajorThresholdInBytes) {
      return {};
    }

    return 1;
  }

  if (hlo.IsOutputFusion()) {
    auto* fusion_root =
        hlo.fused_instructions_computation()->root_instruction();
    if (fusion_root->opcode() != HloOpcode::kAdd) {
      return {};
    }

    for (auto* fusion_root_op : fusion_root->operands()) {
      if (fusion_root_op->opcode() != HloOpcode::kDot) {
        continue;
      }
      if (auto operand_num =
              ProfitableToMakeDotOperandColumnMajor(*fusion_root_op)) {
        auto* operand = fusion_root_op->operand(*operand_num);
        if (operand->opcode() == HloOpcode::kParameter &&
            operand->user_count() == 1) {
          return operand->parameter_number();
        }
      }
    }
  }

  return {};
}

namespace {

// Emits IR for a single (non-batch) dot described by `dot_info` by delegating
// to DotOpEmitter. Fails the TF_RET_CHECK when the element type of
// `target_array` is not one of the listed types, and propagates any error
// from the emitter. On success the value produced by the emitter becomes the
// first component of the returned work group dimensions.
absl::StatusOr<DotOpWorkGroupDim> EmitNonBatchDotOperation(
    DotInfo dot_info, std::string hlo_name,
    const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
    const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array,
    llvm::Value* work_group_id, llvm::Value* executable_run_options_value,
    llvm::IRBuilderBase* b, const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features,
    bool allow_runtime_calls, bool allow_parallelism) {
  PrimitiveType type = target_array.GetShape().element_type();
  TF_RET_CHECK(PRED == type || S8 == type || U8 == type || S16 == type ||
               U16 == type || S32 == type || U32 == type || S64 == type ||
               U64 == type || F16 == type || F32 == type || F64 == type ||
               C64 == type || C128 == type);

  DotOpEmitter emitter(std::move(dot_info), std::move(hlo_name), target_array,
                       lhs_array, rhs_array, addend_array, work_group_id,
                       executable_run_options_value, b, hlo_module_config,
                       target_machine_features, allow_runtime_calls,
                       allow_parallelism);

  TF_ASSIGN_OR_RETURN(uint64_t work_group_dim_x, emitter.Emit());
  return DotOpWorkGroupDim{work_group_dim_x};
}

// Returns a shape with the same element type as `shape` and all of its
// dimensions except the first one, laid out in descending order.
Shape DropFirstDim(const Shape& shape) {
  absl::Span<const int64_t> trailing_dims = shape.dimensions();
  trailing_dims.remove_prefix(1);
  return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
                                                  trailing_dims);
}

// Returns a shape in which the first `n` dimensions of `shape` are merged
// into one dimension whose size is their product (1 when `n` is 0); the
// remaining dimensions follow unchanged. The result uses a descending layout.
Shape CollapseFirstNDims(const Shape& shape, int64_t n) {
  absl::Span<const int64_t> dims(shape.dimensions());
  const auto collapsed_end = dims.begin() + n;

  int64_t collapsed_size = 1;
  for (auto it = dims.begin(); it != collapsed_end; ++it) {
    collapsed_size *= *it;
  }

  DimensionVector collapsed_dims;
  collapsed_dims.push_back(collapsed_size);
  for (auto it = collapsed_end; it != dims.end(); ++it) {
    collapsed_dims.push_back(*it);
  }
  return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
                                                  collapsed_dims);
}

// Returns an IrArray that shares `array`'s base pointer but whose shape has
// the first `n` dimensions collapsed into one. CHECK-fails unless `array` has
// a monotonic dim-0-major layout and at least `n` dimensions.
llvm_ir::IrArray CollapseFirstNDims(llvm::IRBuilderBase* b,
                                    const llvm_ir::IrArray& array, int64_t n) {
  const Shape& shape = array.GetShape();
  CHECK(shape.has_layout() &&
        LayoutUtil::IsMonotonicWithDim0Major(shape.layout()));
  CHECK_GE(shape.dimensions().size(), n);

  Shape collapsed_shape = CollapseFirstNDims(shape, n);
  llvm::Type* collapsed_ir_type =
      llvm_ir::ShapeToIrType(collapsed_shape, b->getContext());
  llvm_ir::IrArray collapsed_array(array.GetBasePointer(), collapsed_ir_type,
                                   std::move(collapsed_shape));
  return collapsed_array;
}

// Verifies the shape of `dim_numbers` that the batch dot lowering relies on:
// exactly one contracting dimension per operand, and batch dimensions that are
// the leading dimensions 0, 1, ..., N-1 on both operands. Returns an error
// status (via TF_RET_CHECK) when any of these does not hold.
absl::Status ValidateDotDimensionNumbers(
    const DotDimensionNumbers& dim_numbers) {
  // Checks some invariants that do not hold in general, but DotDecomposer
  // should have established for us.  This is just a debugging aid.
  TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1);
  // The batch lowering indexes the first rhs contracting dimension as well, so
  // validate it alongside the lhs one.
  TF_RET_CHECK(dim_numbers.rhs_contracting_dimensions_size() == 1);
  std::vector<int64_t> batch_dim_numbers(
      dim_numbers.lhs_batch_dimensions_size());
  absl::c_iota(batch_dim_numbers, 0);
  TF_RET_CHECK(
      absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions()));
  TF_RET_CHECK(
      absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions()));
  return absl::OkStatus();
}

// Slice out the inner array at batch index `batch_index` from `outer_array`.
//
// The returned IrArray starts at the address of element
// [batch_index, 0, ..., 0] of `outer_array` and has `outer_array`'s shape
// with the first dimension dropped.
llvm_ir::IrArray SliceOutInnerArray(llvm_ir::IrArray outer_array,
                                    llvm::Value* batch_index,
                                    llvm::IRBuilderBase* b) {
  const Shape& outer_shape = outer_array.GetShape();
  Shape sliced_shape = DropFirstDim(outer_shape);

  // Index of the first element of the slice: the batch index followed by one
  // zero per remaining dimension.
  llvm::Value* zero = b->getInt64(0);
  std::vector<llvm::Value*> first_element;
  first_element.reserve(sliced_shape.dimensions().size() + 1);
  first_element.push_back(batch_index);
  first_element.insert(first_element.end(), sliced_shape.dimensions().size(),
                       zero);

  llvm_ir::IrArray::Index first_element_index(first_element, outer_shape,
                                              batch_index->getType());
  llvm::Value* slice_base =
      outer_array.EmitArrayElementAddress(first_element_index, b);
  llvm::Type* sliced_ir_type =
      llvm_ir::ShapeToIrType(sliced_shape, b->getContext());
  return llvm_ir::IrArray(slice_base, sliced_ir_type, std::move(sliced_shape));
}

// Returns true when the batch dot `dot` can be lowered to a single runtime
// batch-matmul call: it has at most one batch dimension, an F32 result,
// non-scalar inner operands, and the implementation strategy chosen for the
// inner dot is kEigen.
//
// `dot_info` is an output parameter. Unless the function bails out on the
// batch-dimension count, it is overwritten with the description of the inner
// (batch-less) dot, even when the function goes on to return false.
bool PotentiallyImplementedAsEigenMatmul(
    const HloInstruction& dot, const llvm_ir::IrArray& target_array,
    const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
    llvm::Value* executable_run_options_value, llvm::IRBuilderBase* b,
    const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features, DotInfo& dot_info,
    bool allow_runtime_calls) {
  int64_t num_batch_dims =
      dot.dot_dimension_numbers().lhs_batch_dimensions_size();

  // TODO(kramerb): Remove this limitation.
  if (num_batch_dims > 1) return false;

  // First reshape the inputs to make sure we only have one batch dimension.
  // This is a no-op bitcast because the operands have to be in row-major layout
  // (enforced in CpuLayoutAssignment), and the batch dimensions are the leading
  // dimensions (established by DotDecomposer and checked by
  // ValidateDotDimensionNumbers, which EmitBatchDotOperation runs before
  // calling this function).
  llvm_ir::IrArray lhs_array_reshaped =
      CollapseFirstNDims(b, lhs_array, num_batch_dims);
  llvm_ir::IrArray rhs_array_reshaped =
      CollapseFirstNDims(b, rhs_array, num_batch_dims);
  llvm_ir::IrArray target_array_reshaped =
      CollapseFirstNDims(b, target_array, num_batch_dims);

  // Create a DotInfo representing the batch of "inner" dot operations.
  dot_info.lhs_shape = DropFirstDim(lhs_array_reshaped.GetShape());
  dot_info.rhs_shape = DropFirstDim(rhs_array_reshaped.GetShape());
  dot_info.result_shape = DropFirstDim(target_array_reshaped.GetShape());
  dot_info.dim_nums = dot.dot_dimension_numbers();
  dot_info.dim_nums.clear_lhs_batch_dimensions();
  dot_info.dim_nums.clear_rhs_batch_dimensions();

  // Re-base the contracting dimensions onto the batch-less inner shapes.
  dot_info.dim_nums.set_lhs_contracting_dimensions(
      0, dot_info.dim_nums.lhs_contracting_dimensions(0) - num_batch_dims);
  dot_info.dim_nums.set_rhs_contracting_dimensions(
      0, dot_info.dim_nums.rhs_contracting_dimensions(0) - num_batch_dims);

  PrimitiveType type = target_array.GetShape().element_type();
  if (F32 != type) return false;

  if (ShapeUtil::IsScalar(dot_info.lhs_shape) ||
      ShapeUtil::IsScalar(dot_info.rhs_shape)) {
    // If the operands are scalar, don't emit any loops.
    return false;
  }

  DotImplementationStrategy impl_strategy =
      GetNonBatchDotImplementationStrategy(dot.GetModule()->config(), dot_info,
                                           target_machine_features,
                                           allow_runtime_calls);

  return impl_strategy == DotImplementationStrategy::kEigen;
}

// Emits IR for a batch dot.
//
// If multi-threaded Eigen should be used and the dot qualifies (see
// PotentiallyImplementedAsEigenMatmul), the whole batch is emitted through
// DotOpEmitter::EmitBatch and {1, 1} is returned. Otherwise the batch
// dimensions are collapsed into one and the dot is lowered to non-batch dots
// over slices: either a single inner dot indexed by `work_group_id.x`
// (returning {batch_count, inner x}) when parallelism is allowed and an inner
// operand is large enough, or a sequential loop over the batch (returning
// {1, inner x}).
//
// Returns an error if the dimension numbers fail validation or if any nested
// emitter fails.
absl::StatusOr<DotOpWorkGroupDim> EmitBatchDotOperation(
    const HloInstruction& dot, const llvm_ir::IrArray& target_array,
    const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
    DotOpWorkGroupId work_group_id, llvm::Value* executable_run_options_value,
    llvm::IRBuilderBase* b, const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features,
    bool allow_runtime_calls, bool allow_parallelism) {
  TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers()));

  // First check if the batch can be rendered directly by the runtime,
  // otherwise lower it to a sequence of non-batch dot operations.
  DotInfo dot_info;
  if (ShouldUseMultiThreadedEigen(hlo_module_config) &&
      PotentiallyImplementedAsEigenMatmul(
          dot, target_array, lhs_array, rhs_array, executable_run_options_value,
          b, hlo_module_config, target_machine_features, dot_info,
          allow_runtime_calls)) {
    DotOpEmitter dot_emitter(dot_info, std::string(dot.name()), target_array,
                             lhs_array, rhs_array, /*addend_array=*/nullptr,
                             work_group_id.x, executable_run_options_value, b,
                             hlo_module_config, target_machine_features,
                             allow_runtime_calls, allow_parallelism);

    TF_RETURN_IF_ERROR(dot_emitter.EmitBatch());
    return DotOpWorkGroupDim{1, 1};
  }

  // Lower a batch dot into a (parallel) sequence of non-batch dot operations.

  int64_t num_batch_dims =
      dot.dot_dimension_numbers().lhs_batch_dimensions_size();

  // First reshape the inputs to make sure we only have one batch dimension.
  // This is a no-op bitcast because the operands have to be in row-major
  // layout (enforced in CpuLayoutAssignment), and the batch dimensions are
  // the leading dimensions (established by DotDecomposer and checked by
  // ValidateDotDimensionNumbers above).
  llvm_ir::IrArray lhs_array_reshaped =
      CollapseFirstNDims(b, lhs_array, num_batch_dims);
  llvm_ir::IrArray rhs_array_reshaped =
      CollapseFirstNDims(b, rhs_array, num_batch_dims);
  llvm_ir::IrArray target_array_reshaped =
      CollapseFirstNDims(b, target_array, num_batch_dims);

  int64_t batch_count = lhs_array_reshaped.GetShape().dimensions(0);

  VLOG(2) << "Emitting batch dot operation: batch_count=" << batch_count;

  KernelSupportLibrary ksl(b);

  // Emit the inner non-batch dot operation.
  auto inner_dot = [&](llvm::Value* batch_index) {
    // Create a DotInfo representing the "inner" non-batch dot operation.
    DotInfo inner_dot_info;
    inner_dot_info.lhs_shape = DropFirstDim(lhs_array_reshaped.GetShape());
    inner_dot_info.rhs_shape = DropFirstDim(rhs_array_reshaped.GetShape());
    inner_dot_info.result_shape =
        DropFirstDim(target_array_reshaped.GetShape());
    inner_dot_info.dim_nums = dot.dot_dimension_numbers();
    inner_dot_info.dim_nums.clear_lhs_batch_dimensions();
    inner_dot_info.dim_nums.clear_rhs_batch_dimensions();

    // Re-base the contracting dimensions onto the batch-less inner shapes.
    inner_dot_info.dim_nums.set_lhs_contracting_dimensions(
        0, inner_dot_info.dim_nums.lhs_contracting_dimensions(0) -
               num_batch_dims);
    inner_dot_info.dim_nums.set_rhs_contracting_dimensions(
        0, inner_dot_info.dim_nums.rhs_contracting_dimensions(0) -
               num_batch_dims);

    llvm_ir::IrArray lhs_slice =
        SliceOutInnerArray(lhs_array_reshaped, batch_index, b);
    llvm_ir::IrArray rhs_slice =
        SliceOutInnerArray(rhs_array_reshaped, batch_index, b);
    llvm_ir::IrArray target_slice =
        SliceOutInnerArray(target_array_reshaped, batch_index, b);

    return EmitNonBatchDotOperation(
        std::move(inner_dot_info), std::string(dot.name()), target_slice,
        lhs_slice, rhs_slice, nullptr, work_group_id.y,
        executable_run_options_value, b, hlo_module_config,
        target_machine_features, allow_runtime_calls, allow_parallelism);
  };

  int64_t lhs_size =
      ShapeUtil::ElementsIn(DropFirstDim(lhs_array_reshaped.GetShape()));
  int64_t rhs_size =
      ShapeUtil::ElementsIn(DropFirstDim(rhs_array_reshaped.GetShape()));

  // If inner dot is big enough and we have a work group id, use parallel
  // loop to parallelize the batch dimension. Threshold picked randomly based
  // on micro-benchmarks and needs more tuning.
  static constexpr int64_t kParallelLoopThreshold = 32768;
  if (allow_parallelism && (lhs_size > kParallelLoopThreshold ||
                            rhs_size > kParallelLoopThreshold)) {
    TF_ASSIGN_OR_RETURN(auto inner_dims, inner_dot(work_group_id.x));
    DCHECK_EQ(inner_dims.y, 1);
    return DotOpWorkGroupDim{static_cast<uint64_t>(batch_count),
                             inner_dims.x};
  }

  // Emit sequential loop over the batch dimension, but still might decide to
  // parallelize the inner loop.
  DotOpWorkGroupDim inner_dims;
  TF_RETURN_IF_ERROR(ksl.ForWithStatus(
      llvm_ir::IrName(&dot, "bdot"), /*start=*/0, /*end=*/batch_count,
      /*step=*/1, [&](llvm::Value* indvar) {
        TF_ASSIGN_OR_RETURN(inner_dims, inner_dot(indvar));
        return absl::OkStatus();
      }));
  return DotOpWorkGroupDim{1, inner_dims.x};
}

}  // namespace

// Returns true iff `instr` is a dot instruction with at least one lhs batch
// dimension.
bool IsBatchDot(const HloInstruction& instr) {
  auto* dot_instr = DynCast<HloDotInstruction>(&instr);
  if (dot_instr == nullptr) {
    return false;
  }
  return dot_instr->dot_dimension_numbers().lhs_batch_dimensions_size() > 0;
}

// Returns true iff `dot_info` has at least one lhs batch dimension.
bool IsBatchDot(const DotInfo& dot_info) {
  const auto num_batch_dims = dot_info.dim_nums.lhs_batch_dimensions_size();
  return num_batch_dims > 0;
}

// Builds the DotInfo of the non-batch dot nested inside `batch_dot`: batch
// dimensions are deleted from the operand and result shapes, the batch
// dimension numbers are cleared, and the single contracting dimension of each
// operand is shifted down by the number of batch dimensions.
DotInfo InnerDotInfo(const DotInfo& batch_dot) {
  DCHECK(IsBatchDot(batch_dot)) << "DotInfo must be a batch dot";

  const DotDimensionNumbers& batch_dim_nums = batch_dot.dim_nums;

  DotInfo result;
  result.lhs_shape = ShapeUtil::DeleteDimensions(
      batch_dim_nums.lhs_batch_dimensions(), batch_dot.lhs_shape);
  result.rhs_shape = ShapeUtil::DeleteDimensions(
      batch_dim_nums.rhs_batch_dimensions(), batch_dot.rhs_shape);
  // The result shape is trimmed using the lhs batch dimension numbers.
  result.result_shape = ShapeUtil::DeleteDimensions(
      batch_dim_nums.lhs_batch_dimensions(), batch_dot.result_shape);

  DotDimensionNumbers& inner_dim_nums = result.dim_nums;
  inner_dim_nums = batch_dim_nums;
  inner_dim_nums.clear_lhs_batch_dimensions();
  inner_dim_nums.clear_rhs_batch_dimensions();

  DCHECK_EQ(batch_dot.dim_nums.lhs_contracting_dimensions_size(), 1);
  DCHECK_EQ(batch_dot.dim_nums.rhs_contracting_dimensions_size(), 1);

  const int64_t batch_rank = batch_dim_nums.lhs_batch_dimensions_size();
  const int64_t lhs_contracting =
      inner_dim_nums.lhs_contracting_dimensions(0) - batch_rank;
  inner_dim_nums.set_lhs_contracting_dimensions(0, lhs_contracting);
  const int64_t rhs_contracting =
      inner_dim_nums.rhs_contracting_dimensions(0) - batch_rank;
  inner_dim_nums.set_rhs_contracting_dimensions(0, rhs_contracting);

  return result;
}

// Returns the implementation strategy for the dot `instr`. For a batch dot
// the strategy is determined for its inner (batch-less) dot.
DotImplementationStrategy GetDotImplementationStrategy(
    const HloModuleConfig& config, const HloInstruction& instr,
    const TargetMachineFeatures& target_machine_features,
    bool allow_runtime_calls) {
  DotInfo dot_info(instr);
  if (IsBatchDot(dot_info)) {
    return GetNonBatchDotImplementationStrategy(
        config, InnerDotInfo(dot_info), target_machine_features,
        allow_runtime_calls);
  }
  return GetNonBatchDotImplementationStrategy(
      config, dot_info, target_machine_features, allow_runtime_calls);
}

bool DotImplementationCanHandleTranspose(
    const HloInstruction& dot_instr,
    const TargetMachineFeatures& target_machine_features,
    bool allow_runtime_calls) {
  DotInfo dot_info(dot_instr);

  DotImplementationStrategy impl_strategy =
      GetNonBatchDotImplementationStrategy(
          dot_instr.GetModule()->config(),
          IsBatchDot(dot_info) ? InnerDotInfo(dot_info) : dot_info,
          target_machine_features, allow_runtime_calls);

  return impl_strategy == DotImplementationStrategy::kNaiveLlvmIr ||
         impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemv ||
         impl_strategy == DotImplementationStrategy::kEigen;
}

bool DotOperandsAndResultMustHaveRowMajorLayout(
    const HloInstruction& dot_instr,
    const TargetMachineFeatures& target_machine_features,
    bool allow_runtime_calls) {
  // Batched dots require the batch dimensions to be major. DotDecomposer always
  // moves batch dimensions to the front of the shape, so force a row-major
  // layout.
  if (IsBatchDot(dot_instr)) {
    return true;
  }

  DotImplementationStrategy impl_strategy =
      GetNonBatchDotImplementationStrategy(
          dot_instr.GetModule()->config(), DotInfo(dot_instr),
          target_machine_features, allow_runtime_calls);

  return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm ||
         impl_strategy == DotImplementationStrategy::kEigen;
}

// Entry point for emitting a dot: dispatches to the batch or non-batch
// lowering. CHECK-fails if the root instruction of the enclosing computation
// has outer dimension partitions. Batch dots do not accept an addend
// (`addend_array` must be null), which is reported through TF_RET_CHECK.
absl::StatusOr<DotOpWorkGroupDim> EmitDotOperation(
    const HloInstruction& dot, const llvm_ir::IrArray& target_array,
    const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
    const llvm_ir::IrArray* addend_array, DotOpWorkGroupId work_group_id,
    llvm::Value* executable_run_options_value, llvm::IRBuilderBase* b,
    const HloModuleConfig& hlo_module_config,
    const TargetMachineFeatures& target_machine_features,
    bool allow_runtime_calls, bool allow_parallelism) {
  // This routine assumes that the dot operation is not in a parallelized
  // enclosing computation.
  CHECK(dot.parent()
            ->root_instruction()
            ->backend_config<BackendConfig>()
            ->outer_dimension_partitions()
            .empty());

  if (!IsBatchDot(dot)) {
    // Non-batch dots use only the x component of the work group id.
    return EmitNonBatchDotOperation(
        DotInfo(dot), std::string(dot.name()), target_array, lhs_array,
        rhs_array, addend_array, work_group_id.x, executable_run_options_value,
        b, hlo_module_config, target_machine_features, allow_runtime_calls,
        allow_parallelism);
  }

  TF_RET_CHECK(addend_array == nullptr);
  return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array,
                               work_group_id, executable_run_options_value, b,
                               hlo_module_config, target_machine_features,
                               allow_runtime_calls, allow_parallelism);
}

}  // namespace cpu
}  // namespace xla
