diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 2b57a8dce3de5..77a7059849717 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -264,6 +264,33 @@ computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks, return Result; } +/// Emit an implicit cast to convert \p XRead to type of variable \p V +static llvm::Value *emitImplicitCast(IRBuilder<> &Builder, llvm::Value *XRead, + llvm::Value *V) { + // TODO: Add this functionality to the `AtomicInfo` interface + llvm::Type *XReadType = XRead->getType(); + llvm::Type *VType = V->getType(); + if (llvm::AllocaInst *vAlloca = dyn_cast(V)) + VType = vAlloca->getAllocatedType(); + + if (XReadType->isStructTy() && VType->isStructTy()) + // No need to extract or convert. A direct + // `store` will suffice. + return XRead; + + if (XReadType->isStructTy()) + XRead = Builder.CreateExtractValue(XRead, /*Idxs=*/0); + if (VType->isIntegerTy() && XReadType->isFloatingPointTy()) + XRead = Builder.CreateFPToSI(XRead, VType); + else if (VType->isFloatingPointTy() && XReadType->isIntegerTy()) + XRead = Builder.CreateSIToFP(XRead, VType); + else if (VType->isIntegerTy() && XReadType->isIntegerTy()) + XRead = Builder.CreateIntCast(XRead, VType, true); + else if (VType->isFloatingPointTy() && XReadType->isFloatingPointTy()) + XRead = Builder.CreateFPCast(XRead, VType); + return XRead; +} + /// Make \p Source branch to \p Target. /// /// Handles two situations: @@ -8373,6 +8400,8 @@ OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc, } } checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Read); + if (XRead->getType() != V.Var->getType()) + XRead = emitImplicitCast(Builder, XRead, V.Var); Builder.CreateStore(XRead, V.Var, V.IsVolatile); return Builder.saveIP(); } @@ -8657,6 +8686,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicCapture( return AtomicResult.takeError(); Value *CapturedVal = (IsPostfixUpdate ? AtomicResult->first : AtomicResult->second); + if (CapturedVal->getType() != V.Var->getType()) + CapturedVal = emitImplicitCast(Builder, CapturedVal, V.Var); Builder.CreateStore(CapturedVal, V.Var, V.IsVolatile); checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Capture); diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir index 44e32c3f35f9b..cf73c338d8475 100644 --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -1368,6 +1368,77 @@ llvm.func @omp_atomic_read(%arg0 : !llvm.ptr, %arg1 : !llvm.ptr) -> () { // ----- +// CHECK-LABEL: @omp_atomic_read_implicit_cast +llvm.func @omp_atomic_read_implicit_cast () { +//CHECK: %[[Z:.*]] = alloca float, i64 1, align 4 +//CHECK: %[[Y:.*]] = alloca double, i64 1, align 8 +//CHECK: %[[X:.*]] = alloca [2 x { float, float }], i64 1, align 8 +//CHECK: %[[W:.*]] = alloca i32, i64 1, align 4 +//CHECK: %[[X_ELEMENT:.*]] = getelementptr { float, float }, ptr %3, i64 0 + %0 = llvm.mlir.constant(1 : i64) : i64 + %1 = llvm.alloca %0 x f32 {bindc_name = "z"} : (i64) -> !llvm.ptr + %2 = llvm.mlir.constant(1 : i64) : i64 + %3 = llvm.alloca %2 x f64 {bindc_name = "y"} : (i64) -> !llvm.ptr + %4 = llvm.mlir.constant(1 : i64) : i64 + %5 = llvm.alloca %4 x !llvm.array<2 x struct<(f32, f32)>> {bindc_name = "x"} : (i64) -> !llvm.ptr + %6 = llvm.mlir.constant(1 : i64) : i64 + %7 = llvm.alloca %6 x i32 {bindc_name = "w"} : (i64) -> !llvm.ptr + %8 = llvm.mlir.constant(1 : index) : i64 + %9 = llvm.mlir.constant(2 : index) : i64 + %10 = llvm.mlir.constant(1 : i64) : i64 + %11 = llvm.mlir.constant(0 : i64) : i64 + %12 = llvm.sub %8, %10 overflow : i64 + %13 = llvm.mul %12, %10 overflow : i64 + %14 = llvm.mul %13, %10 overflow : i64 + %15 = llvm.add %14, %11 overflow : i64 + %16 = llvm.mul %10, %9 overflow : i64 + %17 = llvm.getelementptr %5[%15] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(f32, f32)> + +//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = alloca { float, float }, align 8 +//CHECK: call void @__atomic_load(i64 8, ptr %[[X_ELEMENT]], ptr %[[ATOMIC_LOAD_TEMP]], i32 0) +//CHECK: %[[LOAD:.*]] = load { float, float }, ptr %[[ATOMIC_LOAD_TEMP]], align 8 +//CHECK: %[[EXT:.*]] = extractvalue { float, float } %[[LOAD]], 0 +//CHECK: store float %[[EXT]], ptr %[[Y]], align 4 + omp.atomic.read %3 = %17 : !llvm.ptr, !llvm.ptr, !llvm.struct<(f32, f32)> + +//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i32, ptr %[[Z]] monotonic, align 4 +//CHECK: %[[CAST:.*]] = bitcast i32 %[[ATOMIC_LOAD_TEMP]] to float +//CHECK: %[[LOAD:.*]] = fpext float %[[CAST]] to double +//CHECK: store double %[[LOAD]], ptr %[[Y]], align 8 + omp.atomic.read %3 = %1 : !llvm.ptr, !llvm.ptr, f32 + +//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i32, ptr %[[W]] monotonic, align 4 +//CHECK: %[[LOAD:.*]] = sitofp i32 %[[ATOMIC_LOAD_TEMP]] to double +//CHECK: store double %[[LOAD]], ptr %[[Y]], align 8 + omp.atomic.read %3 = %7 : !llvm.ptr, !llvm.ptr, i32 + +//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i64, ptr %[[Y]] monotonic, align 4 +//CHECK: %[[CAST:.*]] = bitcast i64 %[[ATOMIC_LOAD_TEMP]] to double +//CHECK: %[[LOAD:.*]] = fptrunc double %[[CAST]] to float +//CHECK: store float %[[LOAD]], ptr %[[Z]], align 4 + omp.atomic.read %1 = %3 : !llvm.ptr, !llvm.ptr, f64 + +//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i32, ptr %[[W]] monotonic, align 4 +//CHECK: %[[LOAD:.*]] = sitofp i32 %[[ATOMIC_LOAD_TEMP]] to float +//CHECK: store float %[[LOAD]], ptr %[[Z]], align 4 + omp.atomic.read %1 = %7 : !llvm.ptr, !llvm.ptr, i32 + +//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i64, ptr %[[Y]] monotonic, align 4 +//CHECK: %[[CAST:.*]] = bitcast i64 %[[ATOMIC_LOAD_TEMP]] to double +//CHECK: %[[LOAD:.*]] = fptosi double %[[CAST]] to i32 +//CHECK: store i32 %[[LOAD]], ptr %[[W]], align 4 + omp.atomic.read %7 = %3 : !llvm.ptr, !llvm.ptr, f64 + +//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i32, ptr %[[Z]] monotonic, align 4 +//CHECK: %[[CAST:.*]] = bitcast i32 %[[ATOMIC_LOAD_TEMP]] to float +//CHECK: %[[LOAD:.*]] = fptosi float %[[CAST]] to i32 +//CHECK: store i32 %[[LOAD]], ptr %[[W]], align 4 + omp.atomic.read %7 = %1 : !llvm.ptr, !llvm.ptr, f32 + llvm.return +} + +// ----- + // CHECK-LABEL: @omp_atomic_write // CHECK-SAME: (ptr %[[x:.*]], i32 %[[expr:.*]]) llvm.func @omp_atomic_write(%x: !llvm.ptr, %expr: i32) -> () {