//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Identification:
// This step is responsible for finding the patterns that can be lowered to
// complex instructions, and building a graph to represent the complex
// structures. Starting from the "Converging Shuffle" (a shuffle that
// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
// operands are evaluated and identified as "Composite Nodes" (collections of
// instructions that can potentially be lowered to a single complex
// instruction). This is performed by checking the real and imaginary components
// and tracking the data flow for each component while following the operand
// pairs. Validity of each node is expected to be done upon creation, and any
// validation errors should halt traversal and prevent further graph
// construction.
//
// Replacement:
// This step traverses the graph built up by identification, delegating to the
// target to validate and generate the correct intrinsics, and plumbs them
// together connecting each end of the new intrinsics graph to the existing
// use-def chain. This step is assumed to finish successfully, as all
// information is expected to be correct by this point.
//
//
// Internal data structure:
// ComplexDeinterleavingGraph:
// Keeps references to all the valid CompositeNodes formed as part of the
// transformation, and every Instruction contained within said nodes. It also
// holds onto a reference to the root Instruction, and the root node that should
// replace it.
//
// ComplexDeinterleavingCompositeNode:
// A CompositeNode represents a single transformation point; each node should
// transform into a single complex instruction (ignoring vector splitting, which
// would generate more instructions per node). They are identified in a
// depth-first manner, traversing and identifying the operands of each
// instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
// as well as any additional instructions that make up the identified operation
// (Internal instructions should only have uses within their containing node).
// A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// to the IR.
//
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementNode fields of that Node are relevant, where the ReplacementNode
// should be pre-populated.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/ComplexDeinterleavingPass.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/InitializePasses.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "complex-deinterleaving"

STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");

// Hidden command-line switch, on by default, that gates the whole
// transformation: ComplexDeinterleaving::runOnFunction returns early without
// touching the function when this is false.
static cl::opt<bool> ComplexDeinterleavingEnabled(
    "enable-complex-deinterleaving",
    cl::desc("Enable generation of complex instructions"), cl::init(true),
    cl::Hidden);

/// Checks the given shuffle mask, and determines whether said mask is
/// interleaving.
///
/// To be interleaving, a mask must alternate between `i` and `i + (Length /
/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
/// 4x vector interleaving mask would be <0, 2, 1, 3>). A mask with an odd
/// number of elements is never interleaving.
static bool isInterleavingMask(ArrayRef<int> Mask);

/// Checks the given shuffle mask, and determines whether said mask is
/// deinterleaving.
///
/// To be deinterleaving, a mask must increment in steps of 2 starting from its
/// first element. The start value itself is not validated here; the caller in
/// identifyNode additionally requires it to be 0 for the real lanes and 1 for
/// the imaginary lanes.
/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
/// <1, 3, 5, 7>).
static bool isDeinterleavingMask(ArrayRef<int> Mask);

namespace {

/// Legacy pass-manager wrapper: fetches the target lowering and library info
/// for a function and forwards to ComplexDeinterleaving::runOnFunction.
class ComplexDeinterleavingLegacyPass : public FunctionPass {
public:
  static char ID;

  /// \param TM target machine used to look up the per-function
  /// TargetLowering. It defaults to nullptr so the pass remains
  /// default-constructible, but runOnFunction dereferences it
  /// unconditionally, so a real TargetMachine must be supplied before the
  /// pass is run.
  ///
  /// The constructor is explicit (matching ComplexDeinterleavingGraph) so a
  /// TargetMachine pointer never converts to a pass object implicitly.
  explicit ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
      : FunctionPass(ID), TM(TM) {
    initializeComplexDeinterleavingLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  StringRef getPassName() const override {
    return "Complex Deinterleaving Pass";
  }

  /// Runs the transformation on \p F; returns true if the IR was changed.
  bool runOnFunction(Function &F) override;

  /// Requires TargetLibraryInfo (handed to the dead-instruction cleanup) and
  /// declares that the CFG is preserved.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.setPreservesCFG();
  }

private:
  const TargetMachine *TM;
};

class ComplexDeinterleavingGraph;

/// One node of the deinterleaving graph: a real/imaginary instruction pair,
/// the operation and rotation they were identified as, the nodes feeding them
/// and (once emitted or pre-populated) the value that replaces them.
struct ComplexDeinterleavingCompositeNode {

  ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
                                     Instruction *R, Instruction *I)
      : Operation(Op), Real(R), Imag(I) {}

private:
  friend class ComplexDeinterleavingGraph;
  using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
  using RawNodePtr = ComplexDeinterleavingCompositeNode *;

public:
  ComplexDeinterleavingOperation Operation;
  Instruction *Real;
  Instruction *Imag;

  // Instructions that should only exist within this node, there should be no
  // users of these instructions outside the node. An example of these would be
  // the multiply instructions of a partial multiply operation.
  SmallVector<Instruction *> InternalInstructions;

  // Default-initialised because Shuffle nodes never assign a rotation, yet
  // dump() prints this field for every node; leaving it indeterminate made
  // that read undefined.
  ComplexDeinterleavingRotation Rotation =
      ComplexDeinterleavingRotation::Rotation_0;

  // Non-owning edges to the nodes that feed this one. The graph keeps the
  // owning shared_ptrs in its CompositeNodes list.
  SmallVector<RawNodePtr> Operands;

  // Value that replaces this node in the IR; nullptr until it is emitted
  // (Shuffle nodes have it set at creation time).
  Value *ReplacementNode = nullptr;

  /// Records \p I as an instruction internal to this node.
  void addInstruction(Instruction *I) { InternalInstructions.push_back(I); }

  /// Appends \p Node as the next operand; only the raw pointer is stored.
  void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }

  /// Returns true when every user of Real, Imag and the internal instructions
  /// is contained in \p AllInstructions (always true for Shuffle nodes).
  bool hasAllInternalUses(SmallPtrSet<Instruction *, 16> &AllInstructions);

  void dump() { dump(dbgs()); }

  /// Prints a YAML-like description of this node to \p OS.
  void dump(raw_ostream &OS) {
    auto PrintValue = [&](Value *V) {
      if (V) {
        OS << "\"";
        V->print(OS, true);
        OS << "\"\n";
      } else
        OS << "nullptr\n";
    };
    auto PrintNodeRef = [&](RawNodePtr Ptr) {
      if (Ptr)
        OS << Ptr << "\n";
      else
        OS << "nullptr\n";
    };

    OS << "- CompositeNode: " << this << "\n";
    OS << "  Real: ";
    PrintValue(Real);
    OS << "  Imag: ";
    PrintValue(Imag);
    OS << "  ReplacementNode: ";
    PrintValue(ReplacementNode);
    OS << "  Operation: " << static_cast<int>(Operation) << "\n";
    OS << "  Rotation: " << (static_cast<int>(Rotation) * 90) << "\n";
    OS << "  Operands: \n";
    for (const auto &Op : Operands) {
      OS << "    - ";
      PrintNodeRef(Op);
    }
    OS << "  InternalInstructions:\n";
    for (const auto &I : InternalInstructions) {
      OS << "    - \"";
      I->print(OS, true);
      OS << "\"\n";
    }
  }
};

/// Owns the composite nodes discovered while walking back from one
/// interleaving shuffle (the root), and drives their replacement.
class ComplexDeinterleavingGraph {
public:
  using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
  using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
  explicit ComplexDeinterleavingGraph(const TargetLowering *tl) : TL(tl) {}

private:
  const TargetLowering *TL;

  // The root shuffle; assigned by identifyNodes. Initialised to nullptr so an
  // accidental use before identification is a null dereference rather than a
  // read of an indeterminate pointer.
  Instruction *RootValue = nullptr;

  // Node identified for the root's two operands; null if identification
  // failed.
  NodePtr RootNode;

  // Every node submitted so far, in submission order.
  SmallVector<NodePtr> CompositeNodes;

  // The root plus every Real/Imag/internal instruction of the submitted nodes;
  // used to verify that nothing outside the graph uses them.
  SmallPtrSet<Instruction *, 16> AllInstructions;

  /// Creates a node for the pair \p R / \p I without registering it.
  NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
                               Instruction *R, Instruction *I) {
    return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
                                                                I);
  }

  /// Registers \p Node with the graph and records its instructions. Internal
  /// instructions must already have been added to the node, since they are
  /// copied into AllInstructions here.
  NodePtr submitCompositeNode(NodePtr Node) {
    CompositeNodes.push_back(Node);
    AllInstructions.insert(Node->Real);
    AllInstructions.insert(Node->Imag);
    for (auto *I : Node->InternalInstructions)
      AllInstructions.insert(I);
    return Node;
  }

  /// Returns the already-submitted node whose Real/Imag pair is exactly
  /// \p R / \p I, or nullptr if there is none.
  NodePtr getContainingComposite(Value *R, Value *I) {
    for (const auto &CN : CompositeNodes) {
      if (CN->Real == R && CN->Imag == I)
        return CN;
    }
    return nullptr;
  }

  /// Identifies a complex partial multiply pattern and its rotation, based on
  /// the following patterns
  ///
  ///  0:  r: cr + ar * br
  ///      i: ci + ar * bi
  /// 90:  r: cr - ai * bi
  ///      i: ci + ai * br
  /// 180: r: cr - ar * br
  ///      i: ci - ar * bi
  /// 270: r: cr + ai * bi
  ///      i: ci - ai * br
  NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);

  /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
  /// is partially known from identifyPartialMul, filling in the other half of
  /// the complex pair.
  NodePtr identifyNodeWithImplicitAdd(
      Instruction *I, Instruction *J,
      std::pair<Instruction *, Instruction *> &CommonOperandI);

  /// Identifies a complex add pattern and its rotation, based on the following
  /// patterns.
  ///
  /// 90:  r: ar - bi
  ///      i: ai + br
  /// 270: r: ar + bi
  ///      i: ai - br
  NodePtr identifyAdd(Instruction *Real, Instruction *Imag);

  /// Identifies the node formed by the pair \p I / \p J, reusing an existing
  /// node for the same pair when there is one. Returns nullptr if the pair
  /// does not form a supported pattern.
  NodePtr identifyNode(Instruction *I, Instruction *J);

  /// Emits (or returns the cached) replacement value for \p Node, recursing
  /// into its operands first.
  Value *replaceNode(RawNodePtr Node);

public:
  void dump() { dump(dbgs()); }
  void dump(raw_ostream &OS) {
    for (const auto &Node : CompositeNodes)
      Node->dump(OS);
  }

  /// Returns false if the deinterleaving operation should be cancelled for the
  /// current graph.
  bool identifyNodes(Instruction *RootI);

  /// Perform the actual replacement of the underlying instruction graph.
  /// Must only be called after identifyNodes returned true; it has no failure
  /// path of its own and asserts that a replacement was produced.
  void replaceNodes();
};

/// Pass-manager-agnostic driver shared by the legacy and new pass wrappers.
class ComplexDeinterleaving {
public:
  /// \param Lowering target hooks queried for complex-operation support and
  ///        used to emit the replacement IR.
  /// \param LibInfo  library info forwarded to the dead-instruction cleanup.
  ComplexDeinterleaving(const TargetLowering *Lowering,
                        const TargetLibraryInfo *LibInfo)
      : TL(Lowering), TLI(LibInfo) {}

  /// Processes every basic block of \p F; returns true if anything changed.
  bool runOnFunction(Function &F);

private:
  /// Tries to deinterleave each interleaving shuffle found in \p B.
  bool evaluateBasicBlock(BasicBlock *B);

  const TargetLowering *TL = nullptr;
  const TargetLibraryInfo *TLI = nullptr;
};

} // namespace

// Out-of-class definition of the static pass identifier that the constructor
// hands to FunctionPass.
char ComplexDeinterleavingLegacyPass::ID = 0;

// Legacy pass registration boilerplate, keyed on DEBUG_TYPE
// ("complex-deinterleaving"). Presumably these macros generate the
// initializeComplexDeinterleavingLegacyPassPass function that the constructor
// calls - confirm against llvm/PassSupport.h.
INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                      "Complex Deinterleaving", false, false)
INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                    "Complex Deinterleaving", false, false)

/// New pass-manager entry point. Reports all analyses as preserved when the
/// function was left untouched; otherwise only the module proxy is preserved.
PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  const TargetLowering *Lowering =
      TM->getSubtargetImpl(F)->getTargetLowering();
  auto &LibInfo = AM.getResult<llvm::TargetLibraryAnalysis>(F);

  const bool Changed =
      ComplexDeinterleaving(Lowering, &LibInfo).runOnFunction(F);
  if (Changed) {
    PreservedAnalyses Preserved;
    Preserved.preserve<FunctionAnalysisManagerModuleProxy>();
    return Preserved;
  }

  return PreservedAnalyses::all();
}

/// Factory for the legacy pass-manager version of the pass. Returns a
/// heap-allocated pass bound to \p TM; the caller receives the raw pointer.
FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
  return new ComplexDeinterleavingLegacyPass(TM);
}

/// Legacy pass-manager entry point: looks up the target lowering and library
/// info for \p F and delegates to ComplexDeinterleaving. Returns true if the
/// function was modified. TM must be non-null here.
bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
  const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
  // Bind by reference: a plain `auto` would drop the reference and copy the
  // whole TargetLibraryInfo object on every function.
  const TargetLibraryInfo &TLI =
      getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
  return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
}

/// Runs the transformation over every basic block of \p F unless it has been
/// switched off on the command line or the target has no complex-number
/// support. Returns true if any block was modified.
bool ComplexDeinterleaving::runOnFunction(Function &F) {
  if (!ComplexDeinterleavingEnabled) {
    LLVM_DEBUG(
        dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
    return false;
  }

  if (!TL->isComplexDeinterleavingSupported()) {
    LLVM_DEBUG(
        dbgs() << "Complex deinterleaving has been disabled, target does "
                  "not support lowering of complex number operations.\n");
    return false;
  }

  // Every block must be visited, so the result is accumulated without
  // short-circuiting.
  bool AnyBlockChanged = false;
  for (BasicBlock &Block : F) {
    if (evaluateBasicBlock(&Block))
      AnyBlockChanged = true;
  }
  return AnyBlockChanged;
}

// Returns true when Mask has the form <0, H, 1, H+1, ..., H-1, 2H-1> with
// H = Mask.size() / 2, i.e. it zips the low half of the shuffle input with
// the high half.
static bool isInterleavingMask(ArrayRef<int> Mask) {
  // An odd-length mask cannot be split into (low, high) pairs.
  if (Mask.size() % 2 != 0)
    return false;

  const int Half = Mask.size() / 2;
  for (int Pair = 0; Pair != Half; ++Pair) {
    const int Low = Mask[2 * Pair];
    const int High = Mask[2 * Pair + 1];
    if (Low != Pair)
      return false;
    if (High != Pair + Half)
      return false;
  }

  return true;
}

// Returns true when Mask is <Offset, Offset + 2, Offset + 4, ...> for
// Offset = Mask[0], i.e. it selects every second lane of its input. The
// offset itself is not restricted here; callers check it separately.
//
// Every element of the mask is verified. Stopping at Mask.size() / 2 (as this
// function previously did) left the upper half unchecked, so e.g. the
// two-element mask <0, 3> was accepted as deinterleaving. An empty mask is
// rejected rather than indexed.
static bool isDeinterleavingMask(ArrayRef<int> Mask) {
  if (Mask.empty())
    return false;

  const int Offset = Mask[0];
  const int NumElements = Mask.size();

  for (int Idx = 1; Idx < NumElements; ++Idx) {
    if (Mask[Idx] != (Idx * 2) + Offset)
      return false;
  }

  return true;
}

/// Scans \p B for interleaving shuffles, builds a graph from each one and,
/// where identification succeeds, replaces it. The replaced root shuffles are
/// cleaned up only after the scan has finished. Returns true if at least one
/// root was replaced.
bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
  SmallVector<Instruction *> ReplacedRoots;

  for (Instruction &Inst : *B) {
    // Look for a shufflevector that takes separate vectors of the real and
    // imaginary components and recombines them into a single vector.
    auto *Shuffle = dyn_cast<ShuffleVectorInst>(&Inst);
    if (Shuffle && isInterleavingMask(Shuffle->getShuffleMask())) {
      ComplexDeinterleavingGraph Graph(TL);
      if (Graph.identifyNodes(Shuffle)) {
        Graph.replaceNodes();
        ReplacedRoots.push_back(Shuffle);
      }
    }
  }

  // Deletion is deferred to here so the block is not mutated by removals
  // while it is being iterated above.
  for (Instruction *Root : ReplacedRoots) {
    if (Root && Root->getParent() != nullptr)
      llvm::RecursivelyDeleteTriviallyDeadInstructions(Root, TLI);
  }

  return !ReplacedRoots.empty();
}

/// Tries to identify \p Real / \p Imag as a pair of fmul instructions forming
/// a partial complex multiply that has no accumulator operand (the resulting
/// node gets exactly two operands: the common node and the uncommon node).
///
/// \p PartialMatch is an in/out parameter. The caller supplies one half of the
/// common-operand pair; this function fills in the other half from the operand
/// shared by both multiplies. Note that \p PartialMatch may already have been
/// modified when this function returns nullptr.
///
/// Returns the submitted CMulPartial node, or nullptr if the pattern does not
/// match or either operand pair cannot be identified.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
    Instruction *Real, Instruction *Imag,
    std::pair<Instruction *, Instruction *> &PartialMatch) {
  LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
                    << "\n");

  // Both multiplies are consumed by the node being built, so they must not
  // have any other user.
  if (!Real->hasOneUse() || !Imag->hasOneUse()) {
    LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
    return nullptr;
  }

  if (Real->getOpcode() != Instruction::FMul ||
      Imag->getOpcode() != Instruction::FMul) {
    LLVM_DEBUG(dbgs() << "  - Real or imaginary instruction is not fmul\n");
    return nullptr;
  }

  Instruction *R0 = dyn_cast<Instruction>(Real->getOperand(0));
  Instruction *R1 = dyn_cast<Instruction>(Real->getOperand(1));
  Instruction *I0 = dyn_cast<Instruction>(Imag->getOperand(0));
  Instruction *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
  if (!R0 || !R1 || !I0 || !I1) {
    LLVM_DEBUG(dbgs() << "  - Mul operand not Instruction\n");
    return nullptr;
  }

  // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
  // rotations and use the operand.
  //
  // Resulting value of Negs: no fneg -> 0, real side only -> 1, both sides
  // -> 2, imaginary side only -> 3. It is cast directly to the rotation enum
  // below (presumably Rotation_0/90/180/270 are numbered 0..3 - confirm
  // against the enum definition).
  // At most one fneg is stripped per multiply: operand 0 is preferred when
  // both operands are fneg.
  unsigned Negs = 0;
  SmallVector<Instruction *> FNegs;
  if (R0->getOpcode() == Instruction::FNeg ||
      R1->getOpcode() == Instruction::FNeg) {
    Negs |= 1;
    if (R0->getOpcode() == Instruction::FNeg) {
      FNegs.push_back(R0);
      R0 = dyn_cast<Instruction>(R0->getOperand(0));
    } else {
      FNegs.push_back(R1);
      R1 = dyn_cast<Instruction>(R1->getOperand(0));
    }
    // The negated value itself must be an instruction.
    if (!R0 || !R1)
      return nullptr;
  }
  if (I0->getOpcode() == Instruction::FNeg ||
      I1->getOpcode() == Instruction::FNeg) {
    Negs |= 2;
    Negs ^= 1;
    if (I0->getOpcode() == Instruction::FNeg) {
      FNegs.push_back(I0);
      I0 = dyn_cast<Instruction>(I0->getOperand(0));
    } else {
      FNegs.push_back(I1);
      I1 = dyn_cast<Instruction>(I1->getOperand(0));
    }
    // The negated value itself must be an instruction.
    if (!I0 || !I1)
      return nullptr;
  }

  ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;

  Instruction *CommonOperand;
  Instruction *UncommonRealOp;
  Instruction *UncommonImagOp;

  // Find the operand shared by both multiplies; the remaining operand of the
  // real multiply is the "uncommon" one.
  if (R0 == I0 || R0 == I1) {
    CommonOperand = R0;
    UncommonRealOp = R1;
  } else if (R1 == I0 || R1 == I1) {
    CommonOperand = R1;
    UncommonRealOp = R0;
  } else {
    LLVM_DEBUG(dbgs() << "  - No equal operand\n");
    return nullptr;
  }

  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
  // For the 90 and 270 rotations the uncommon operands appear in the opposite
  // real/imaginary roles, so swap them back.
  if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_270)
    std::swap(UncommonRealOp, UncommonImagOp);

  // Between identifyPartialMul and here we need to have found a complete valid
  // pair from the CommonOperand of each part.
  if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_180)
    PartialMatch.first = CommonOperand;
  else
    PartialMatch.second = CommonOperand;

  if (!PartialMatch.first || !PartialMatch.second) {
    LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
    return nullptr;
  }

  // The recursive identifyNode calls may submit nodes to the graph, so their
  // order is significant.
  NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
  if (!CommonNode) {
    LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
    return nullptr;
  }

  NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
  if (!UncommonNode) {
    LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
    return nullptr;
  }

  NodePtr Node = prepareCompositeNode(
      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
  Node->Rotation = Rotation;
  Node->addOperand(CommonNode);
  Node->addOperand(UncommonNode);
  // The stripped fnegs become internal instructions of the node, so they are
  // covered by the internal-use check.
  Node->InternalInstructions.append(FNegs);
  return submitCompositeNode(Node);
}

/// Tries to identify \p Real / \p Imag as the outer add/sub pair of a partial
/// complex multiply with accumulator: each is `C (+|-) (X * Y)`, where the two
/// multiplies share one operand.
///
/// The rotation is derived from the add/sub opcodes of \p Real and \p Imag.
/// Both must carry the `contract` fast-math flag. The resulting node has three
/// operands, in this order: the common node, the uncommon node and the node
/// identified for the accumulator pair.
///
/// Returns the submitted CMulPartial node, or nullptr if the pattern does not
/// match.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
                                               Instruction *Imag) {
  LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
                    << "\n");
  // Determine rotation
  ComplexDeinterleavingRotation Rotation;
  if (Real->getOpcode() == Instruction::FAdd &&
      Imag->getOpcode() == Instruction::FAdd)
    Rotation = ComplexDeinterleavingRotation::Rotation_0;
  else if (Real->getOpcode() == Instruction::FSub &&
           Imag->getOpcode() == Instruction::FAdd)
    Rotation = ComplexDeinterleavingRotation::Rotation_90;
  else if (Real->getOpcode() == Instruction::FSub &&
           Imag->getOpcode() == Instruction::FSub)
    Rotation = ComplexDeinterleavingRotation::Rotation_180;
  else if (Real->getOpcode() == Instruction::FAdd &&
           Imag->getOpcode() == Instruction::FSub)
    Rotation = ComplexDeinterleavingRotation::Rotation_270;
  else {
    LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
    return nullptr;
  }

  if (!Real->getFastMathFlags().allowContract() ||
      !Imag->getFastMathFlags().allowContract()) {
    LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
    return nullptr;
  }

  // Operand 0 is the accumulator side, operand 1 the multiply.
  Value *CR = Real->getOperand(0);
  Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
  if (!RealMulI)
    return nullptr;
  Value *CI = Imag->getOperand(0);
  Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
  if (!ImagMulI)
    return nullptr;

  // The multiplies become internal to the node, so nothing else may use them.
  if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
    LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
    return nullptr;
  }

  Instruction *R0 = dyn_cast<Instruction>(RealMulI->getOperand(0));
  Instruction *R1 = dyn_cast<Instruction>(RealMulI->getOperand(1));
  Instruction *I0 = dyn_cast<Instruction>(ImagMulI->getOperand(0));
  Instruction *I1 = dyn_cast<Instruction>(ImagMulI->getOperand(1));
  if (!R0 || !R1 || !I0 || !I1) {
    LLVM_DEBUG(dbgs() << "  - Mul operand not Instruction\n");
    return nullptr;
  }

  Instruction *CommonOperand;
  Instruction *UncommonRealOp;
  Instruction *UncommonImagOp;

  // Find the operand shared by both multiplies; the remaining operand of the
  // real multiply is the "uncommon" one.
  if (R0 == I0 || R0 == I1) {
    CommonOperand = R0;
    UncommonRealOp = R1;
  } else if (R1 == I0 || R1 == I1) {
    CommonOperand = R1;
    UncommonRealOp = R0;
  } else {
    LLVM_DEBUG(dbgs() << "  - No equal operand\n");
    return nullptr;
  }

  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
  // For the 90 and 270 rotations the uncommon operands appear in the opposite
  // real/imaginary roles, so swap them back.
  if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_270)
    std::swap(UncommonRealOp, UncommonImagOp);

  // Only one half of the common-operand pair is known at this point (the real
  // half for rotations 0/180, the imaginary half for 90/270); the other half
  // is filled in by identifyNodeWithImplicitAdd.
  std::pair<Instruction *, Instruction *> PartialMatch(
      (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
       Rotation == ComplexDeinterleavingRotation::Rotation_180)
          ? CommonOperand
          : nullptr,
      (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
       Rotation == ComplexDeinterleavingRotation::Rotation_270)
          ? CommonOperand
          : nullptr);

  // The accumulator operands are arbitrary Values; bail out instead of
  // asserting (or invoking undefined behaviour in release builds) when either
  // one is not an instruction.
  auto *CRInst = dyn_cast<Instruction>(CR);
  auto *CIInst = dyn_cast<Instruction>(CI);
  if (!CRInst || !CIInst) {
    LLVM_DEBUG(dbgs() << "  - Common operands are not instructions.\n");
    return nullptr;
  }

  NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
  if (!CNode) {
    LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
    return nullptr;
  }

  // The recursive identifyNode calls may submit nodes to the graph, so their
  // order is significant.
  NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
  if (!UncommonRes) {
    LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
    return nullptr;
  }

  // identifyNodeWithImplicitAdd only succeeds once both halves are known.
  assert(PartialMatch.first && PartialMatch.second);
  NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
  if (!CommonRes) {
    LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
    return nullptr;
  }

  NodePtr Node = prepareCompositeNode(
      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
  Node->addInstruction(RealMulI);
  Node->addInstruction(ImagMulI);
  Node->Rotation = Rotation;
  Node->addOperand(CommonRes);
  Node->addOperand(UncommonRes);
  Node->addOperand(CNode);
  return submitCompositeNode(Node);
}

/// Tries to identify \p Real / \p Imag as a rotated complex add:
///   90:  r: ar - bi, i: ai + br
///   270: r: ar + bi, i: ai - br
/// Both the floating-point and the integer add/sub opcodes are accepted.
/// Returns the submitted CAdd node (operands: A, then B), or nullptr if the
/// opcodes do not match or either operand pair cannot be identified.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
  LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");

  const unsigned RealOpc = Real->getOpcode();
  const unsigned ImagOpc = Imag->getOpcode();

  // sub/add selects the 90 degree rotation, add/sub the 270 degree one.
  const bool IsSubThenAdd =
      (RealOpc == Instruction::FSub && ImagOpc == Instruction::FAdd) ||
      (RealOpc == Instruction::Sub && ImagOpc == Instruction::Add);
  const bool IsAddThenSub =
      (RealOpc == Instruction::FAdd && ImagOpc == Instruction::FSub) ||
      (RealOpc == Instruction::Add && ImagOpc == Instruction::Sub);

  if (!IsSubThenAdd && !IsAddThenSub) {
    LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
    return nullptr;
  }

  const ComplexDeinterleavingRotation Rotation =
      IsSubThenAdd ? ComplexDeinterleavingRotation::Rotation_90
                   : ComplexDeinterleavingRotation::Rotation_270;

  // Real is `ar (+|-) bi`, Imag is `ai (+|-) br`.
  auto *LhsReal = dyn_cast<Instruction>(Real->getOperand(0));
  auto *RhsImag = dyn_cast<Instruction>(Real->getOperand(1));
  auto *LhsImag = dyn_cast<Instruction>(Imag->getOperand(0));
  auto *RhsReal = dyn_cast<Instruction>(Imag->getOperand(1));

  const bool AllInstructionOperands = LhsReal && LhsImag && RhsReal && RhsImag;
  if (!AllInstructionOperands) {
    LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
    return nullptr;
  }

  // A is identified before B; each call may submit nodes to the graph.
  NodePtr NodeA = identifyNode(LhsReal, LhsImag);
  if (!NodeA) {
    LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
    return nullptr;
  }

  NodePtr NodeB = identifyNode(RhsReal, RhsImag);
  if (!NodeB) {
    LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
    return nullptr;
  }

  NodePtr AddNode =
      prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
  AddNode->Rotation = Rotation;
  AddNode->addOperand(NodeA);
  AddNode->addOperand(NodeB);
  return submitCompositeNode(AddNode);
}

// Returns true when A and B are an add paired with a sub (in either order) of
// the same flavour: both floating-point or both integer.
static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
  const unsigned OtherOpc = B->getOpcode();

  switch (A->getOpcode()) {
  case Instruction::FSub:
    return OtherOpc == Instruction::FAdd;
  case Instruction::FAdd:
    return OtherOpc == Instruction::FSub;
  case Instruction::Sub:
    return OtherOpc == Instruction::Add;
  case Instruction::Add:
    return OtherOpc == Instruction::Sub;
  default:
    return false;
  }
}

// Returns true when both A and B are binary operators whose two operands are
// each an fmul. The opcode of the outer binary operator is not restricted
// here.
static bool isInstructionPairMul(Instruction *A, Instruction *B) {
  auto IsBinOpOfTwoFMuls = [](Instruction *I) {
    return match(I, m_BinOp(m_FMul(m_Value(), m_Value()),
                            m_FMul(m_Value(), m_Value())));
  };

  return IsBinOpOfTwoFMuls(A) && IsBinOpOfTwoFMuls(B);
}

/// Identifies the composite node formed by the pair \p Real / \p Imag.
///
/// In order, this:
///  1. returns the existing node if the exact pair was already submitted;
///  2. if both are shufflevectors, validates them as the two deinterleaving
///     shuffles of one source vector and submits a Shuffle node whose
///     ReplacementNode is that source vector;
///  3. otherwise dispatches to identifyPartialMul or identifyAdd, provided the
///     target supports the operation on the double-width vector type.
///
/// Returns nullptr when no supported pattern matches.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
  LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n");
  if (NodePtr CN = getContainingComposite(Real, Imag)) {
    LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
    return CN;
  }

  auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
  auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
  if (RealShuffle && ImagShuffle) {
    // Both shuffles must read from a single source: the second shuffle operand
    // has to be undef or zero.
    Value *RealOp1 = RealShuffle->getOperand(1);
    if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
      LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
      return nullptr;
    }
    Value *ImagOp1 = ImagShuffle->getOperand(1);
    if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
      LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
      return nullptr;
    }

    Value *RealOp0 = RealShuffle->getOperand(0);
    Value *ImagOp0 = ImagShuffle->getOperand(0);

    // The real and imaginary lanes must be extracted from the same vector.
    if (RealOp0 != ImagOp0) {
      LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
      return nullptr;
    }

    ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
    ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
    if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
      LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
      return nullptr;
    }

    // Real lanes are the even elements (mask starts at 0), imaginary lanes the
    // odd elements (mask starts at 1).
    if (RealMask[0] != 0 || ImagMask[0] != 1) {
      LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
      return nullptr;
    }

    // Type checking, the shuffle type should be a vector type of the same
    // scalar type, but half the size
    auto CheckType = [&](ShuffleVectorInst *Shuffle) {
      Value *Op = Shuffle->getOperand(0);
      auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
      auto *OpTy = cast<FixedVectorType>(Op->getType());

      if (OpTy->getScalarType() != ShuffleTy->getScalarType())
        return false;
      if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
        return false;

      return true;
    };

    // Type check plus a bound check on the last mask element.
    auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
      if (!CheckType(Shuffle))
        return false;

      ArrayRef<int> Mask = Shuffle->getShuffleMask();
      int Last = *Mask.rbegin();

      Value *Op = Shuffle->getOperand(0);
      auto *OpTy = cast<FixedVectorType>(Op->getType());
      int NumElements = OpTy->getNumElements();

      // Ensure that the deinterleaving shuffle only pulls from the first
      // shuffle operand.
      return Last < NumElements;
    };

    if (RealShuffle->getType() != ImagShuffle->getType()) {
      LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
      return nullptr;
    }
    if (!CheckDeinterleavingShuffle(RealShuffle)) {
      LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
      return nullptr;
    }
    if (!CheckDeinterleavingShuffle(ImagShuffle)) {
      LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
      return nullptr;
    }

    // A Shuffle node is a leaf: its replacement is simply the still-interleaved
    // source vector, so it is populated up front.
    NodePtr PlaceholderNode =
        prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Shuffle,
                             RealShuffle, ImagShuffle);
    PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
    return submitCompositeNode(PlaceholderNode);
  }
  // Exactly one of the pair is a shuffle: not a pattern handled here.
  if (RealShuffle || ImagShuffle)
    return nullptr;

  // The target is queried with the interleaved type, which has twice as many
  // elements as each of the deinterleaved halves. Note that cast<> asserts
  // that Real has a fixed vector type.
  auto *VTy = cast<FixedVectorType>(Real->getType());
  auto *NewVTy =
      FixedVectorType::get(VTy->getScalarType(), VTy->getNumElements() * 2);

  // The partial multiply is tried before the add; a pair that matches the
  // multiply shape never falls through to identifyAdd.
  if (TL->isComplexDeinterleavingOperationSupported(
          ComplexDeinterleavingOperation::CMulPartial, NewVTy) &&
      isInstructionPairMul(Real, Imag)) {
    return identifyPartialMul(Real, Imag);
  }

  if (TL->isComplexDeinterleavingOperationSupported(
          ComplexDeinterleavingOperation::CAdd, NewVTy) &&
      isInstructionPairAdd(Real, Imag)) {
    return identifyAdd(Real, Imag);
  }

  return nullptr;
}

/// Builds the graph rooted at \p RootI, which must be a shufflevector whose
/// two operands are both instructions. Returns true only if the root pair was
/// identified and every submitted node is used exclusively inside the graph.
bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
  Instruction *RealOp = nullptr;
  Instruction *ImagOp = nullptr;
  const bool IsShuffleOfInstructions =
      match(RootI, m_Shuffle(m_Instruction(RealOp), m_Instruction(ImagOp)));
  if (!IsShuffleOfInstructions)
    return false;

  RootValue = RootI;
  AllInstructions.insert(RootI);
  RootNode = identifyNode(RealOp, ImagOp);

  LLVM_DEBUG({
    Function *F = RootI->getFunction();
    BasicBlock *B = RootI->getParent();
    dbgs() << "Complex deinterleaving graph for " << F->getName()
           << "::" << B->getName() << ".\n";
    dump(dbgs());
    dbgs() << "\n";
  });

  // Check all instructions have internal uses
  for (const NodePtr &Candidate : CompositeNodes) {
    if (Candidate->hasAllInternalUses(AllInstructions))
      continue;
    LLVM_DEBUG(dbgs() << "  - Invalid internal uses\n");
    return false;
  }

  return static_cast<bool>(RootNode);
}

/// Returns the replacement value for \p Node, emitting it through the target
/// if it does not exist yet. Operands are replaced first, in operand order; a
/// third operand, when present, is passed to the target as the accumulator.
Value *ComplexDeinterleavingGraph::replaceNode(
    ComplexDeinterleavingGraph::RawNodePtr Node) {
  // Already emitted, or pre-populated as for Shuffle nodes.
  if (Value *Existing = Node->ReplacementNode)
    return Existing;

  Value *Lhs = replaceNode(Node->Operands[0]);
  Value *Rhs = replaceNode(Node->Operands[1]);
  Value *Acc = nullptr;
  if (Node->Operands.size() > 2)
    Acc = replaceNode(Node->Operands[2]);

  assert(Lhs->getType() == Rhs->getType() &&
         "Node inputs need to be of the same type");

  Value *Replacement = TL->createComplexDeinterleavingIR(
      Node->Real, Node->Operation, Node->Rotation, Lhs, Rhs, Acc);
  Node->ReplacementNode = Replacement;

  assert(Replacement && "Target failed to create Intrinsic call.");
  ++NumComplexTransformations;
  return Replacement;
}

void ComplexDeinterleavingGraph::replaceNodes() {
  Value *R = replaceNode(RootNode.get());
  assert(R && "Unable to find replacement for RootValue");
  RootValue->replaceAllUsesWith(R);
}

/// Returns true when every user of this node's Real, Imag and internal
/// instructions is a member of \p AllInstructions. Shuffle nodes are exempt
/// and always report true.
bool ComplexDeinterleavingCompositeNode::hasAllInternalUses(
    SmallPtrSet<Instruction *, 16> &AllInstructions) {
  if (Operation == ComplexDeinterleavingOperation::Shuffle)
    return true;

  // True when no user of Inst lies outside the graph's instruction set.
  auto OnlyUsedInGraph = [&AllInstructions](Instruction *Inst) {
    for (auto *U : Inst->users()) {
      if (!AllInstructions.contains(cast<Instruction>(U)))
        return false;
    }
    return true;
  };

  if (!OnlyUsedInGraph(Real) || !OnlyUsedInGraph(Imag))
    return false;

  for (auto *Internal : InternalInstructions) {
    if (!OnlyUsedInGraph(Internal))
      return false;
  }
  return true;
}
