]> git.lizzy.rs Git - rust.git/commitdiff
Add intrinsics for float arithmetic with `fast` flag enabled
authorUlrik Sverdrup <bluss@users.noreply.github.com>
Mon, 14 Mar 2016 23:01:12 +0000 (00:01 +0100)
committerUlrik Sverdrup <bluss@users.noreply.github.com>
Fri, 18 Mar 2016 16:31:41 +0000 (17:31 +0100)
`fast` a.k.a UnsafeAlgebra is the flag for enabling all "unsafe"
(according to llvm) float optimizations.

See LangRef for more information http://llvm.org/docs/LangRef.html#fast-math-flags

Providing these operations with less precise associativity rules (for
example) is useful to numerical applications.

For example, the summation loop:

    let mut sum = 0.;
    for element in data {
        sum += *element;
    }

Using the default floating point semantics, this loop expresses that the
floats must be added in sequence, one after another. This constraint
is usually completely unintended, and it means that no autovectorization
is possible.

src/libcore/intrinsics.rs
src/librustc_llvm/lib.rs
src/librustc_trans/trans/build.rs
src/librustc_trans/trans/builder.rs
src/librustc_trans/trans/intrinsic.rs
src/librustc_typeck/check/intrinsic.rs
src/rustllvm/RustWrapper.cpp
src/test/codegen/float_math.rs [new file with mode: 0644]
src/test/run-pass/float_math.rs [new file with mode: 0644]

index 0417ef84163ab30bccc64d6857b3f55de5af4dcf..69cfd0368d635d95637511b9145fa01ae19a6a7b 100644 (file)
@@ -539,6 +539,32 @@ pub fn volatile_copy_nonoverlapping_memory<T>(dst: *mut T, src: *const T,
     /// Returns the nearest integer to an `f64`. Rounds half-way cases away from zero.
     pub fn roundf64(x: f64) -> f64;
 
+    /// Float addition emitted with LLVM's `fast` flag, so it may be optimized using
+    /// algebraic rules (e.g. reassociation). May assume inputs are finite.
+    #[cfg(not(stage0))]
+    pub fn fadd_fast<T>(a: T, b: T) -> T;
+
+    /// Float subtraction emitted with LLVM's `fast` flag, so it may be optimized using
+    /// algebraic rules (e.g. reassociation). May assume inputs are finite.
+    #[cfg(not(stage0))]
+    pub fn fsub_fast<T>(a: T, b: T) -> T;
+
+    /// Float multiplication emitted with LLVM's `fast` flag, so it may be optimized using
+    /// algebraic rules (e.g. reassociation). May assume inputs are finite.
+    #[cfg(not(stage0))]
+    pub fn fmul_fast<T>(a: T, b: T) -> T;
+
+    /// Float division emitted with LLVM's `fast` flag, so it may be optimized using
+    /// algebraic rules (e.g. reassociation). May assume inputs are finite.
+    #[cfg(not(stage0))]
+    pub fn fdiv_fast<T>(a: T, b: T) -> T;
+
+    /// Float remainder emitted with LLVM's `fast` flag, so it may be optimized using
+    /// algebraic rules (e.g. reassociation). May assume inputs are finite.
+    #[cfg(not(stage0))]
+    pub fn frem_fast<T>(a: T, b: T) -> T;
+
+
     /// Returns the number of bits set in an integer type `T`
     pub fn ctpop<T>(x: T) -> T;
 
index a1dca796d9a42db6a7be6e768c0f33aeb9842a48..c1b909bd877a85bd3d61bcf7601b494a25c454d1 100644 (file)
@@ -1310,6 +1310,7 @@ pub fn LLVMBuildFNeg(B: BuilderRef, V: ValueRef, Name: *const c_char)
                          -> ValueRef;
     pub fn LLVMBuildNot(B: BuilderRef, V: ValueRef, Name: *const c_char)
                         -> ValueRef;
+    pub fn LLVMRustSetHasUnsafeAlgebra(Instr: ValueRef); // sets the `fast` flag on Instr
 
     /* Memory */
     pub fn LLVMBuildAlloca(B: BuilderRef, Ty: TypeRef, Name: *const c_char)
index 22536f2dc434d3ca9c70af6cc4f02b4882a1d045..53e64c086a8dbf2917a14a1653729c884d2d631a 100644 (file)
@@ -221,6 +221,18 @@ pub fn FAdd(cx: Block,
     B(cx).fadd(lhs, rhs)
 }
 
+pub fn FAddFast(cx: Block,
+                lhs: ValueRef,
+                rhs: ValueRef,
+                debug_loc: DebugLoc)
+            -> ValueRef {
+    if cx.unreachable.get() {
+        return _Undef(lhs);
+    }
+    debug_loc.apply(cx.fcx);
+    B(cx).fadd_fast(lhs, rhs)
+}
+
 pub fn Sub(cx: Block,
            lhs: ValueRef,
            rhs: ValueRef,
@@ -269,6 +281,18 @@ pub fn FSub(cx: Block,
     B(cx).fsub(lhs, rhs)
 }
 
+pub fn FSubFast(cx: Block,
+                lhs: ValueRef,
+                rhs: ValueRef,
+                debug_loc: DebugLoc)
+            -> ValueRef {
+    if cx.unreachable.get() {
+        return _Undef(lhs);
+    }
+    debug_loc.apply(cx.fcx);
+    B(cx).fsub_fast(lhs, rhs)
+}
+
 pub fn Mul(cx: Block,
            lhs: ValueRef,
            rhs: ValueRef,
@@ -317,6 +341,18 @@ pub fn FMul(cx: Block,
     B(cx).fmul(lhs, rhs)
 }
 
+pub fn FMulFast(cx: Block,
+                lhs: ValueRef,
+                rhs: ValueRef,
+                debug_loc: DebugLoc)
+            -> ValueRef {
+    if cx.unreachable.get() {
+        return _Undef(lhs);
+    }
+    debug_loc.apply(cx.fcx);
+    B(cx).fmul_fast(lhs, rhs)
+}
+
 pub fn UDiv(cx: Block,
             lhs: ValueRef,
             rhs: ValueRef,
@@ -365,6 +401,18 @@ pub fn FDiv(cx: Block,
     B(cx).fdiv(lhs, rhs)
 }
 
+pub fn FDivFast(cx: Block,
+                lhs: ValueRef,
+                rhs: ValueRef,
+                debug_loc: DebugLoc)
+            -> ValueRef {
+    if cx.unreachable.get() {
+        return _Undef(lhs);
+    }
+    debug_loc.apply(cx.fcx);
+    B(cx).fdiv_fast(lhs, rhs)
+}
+
 pub fn URem(cx: Block,
             lhs: ValueRef,
             rhs: ValueRef,
@@ -401,6 +449,18 @@ pub fn FRem(cx: Block,
     B(cx).frem(lhs, rhs)
 }
 
+pub fn FRemFast(cx: Block,
+                lhs: ValueRef,
+                rhs: ValueRef,
+                debug_loc: DebugLoc)
+            -> ValueRef {
+    if cx.unreachable.get() {
+        return _Undef(lhs);
+    }
+    debug_loc.apply(cx.fcx);
+    B(cx).frem_fast(lhs, rhs)
+}
+
 pub fn Shl(cx: Block,
            lhs: ValueRef,
            rhs: ValueRef,
index 7f8e8393e8c4e94928d8fd0b01f7b871f2f145d0..869e9212b1cf6faa4075d149e93d5c309be5d496 100644 (file)
@@ -226,6 +226,15 @@ pub fn fadd(&self, lhs: ValueRef, rhs: ValueRef) -> ValueRef {
         }
     }
 
+    pub fn fadd_fast(&self, lhs: ValueRef, rhs: ValueRef) -> ValueRef {
+        self.count_insn("fadd");
+        unsafe {
+            let instr = llvm::LLVMBuildFAdd(self.llbuilder, lhs, rhs, noname());
+            llvm::LLVMRustSetHasUnsafeAlgebra(instr);
+            instr
+        }
+    }
+
     pub fn sub(&self, lhs: ValueRef, rhs: ValueRef) -> ValueRef {
         self.count_insn("sub");
         unsafe {
@@ -254,6 +263,15 @@ pub fn fsub(&self, lhs: ValueRef, rhs: ValueRef) -> ValueRef {
         }
     }
 
+    pub fn fsub_fast(&self, lhs: ValueRef, rhs: ValueRef) -> ValueRef {
+        self.count_insn("sub");
+        unsafe {
+            let instr = llvm::LLVMBuildFSub(self.llbuilder, lhs, rhs, noname());
+            llvm::LLVMRustSetHasUnsafeAlgebra(instr);
+            instr
+        }
+    }
+
     pub fn mul(&self, lhs: ValueRef, rhs: ValueRef) -> ValueRef {
         self.count_insn("mul");
         unsafe {
@@ -282,6 +300,16 @@ pub fn fmul(&self, lhs: ValueRef, rhs: ValueRef) -> ValueRef {
         }
     }
 
+    pub fn fmul_fast(&self, lhs: ValueRef, rhs: ValueRef) -> ValueRef {
+        self.count_insn("fmul");
+        unsafe {
+            let instr = llvm::LLVMBuildFMul(self.llbuilder, lhs, rhs, noname());
+            llvm::LLVMRustSetHasUnsafeAlgebra(instr);
+            instr
+        }
+    }
+
+
     pub fn udiv(&self, lhs: ValueRef, rhs: ValueRef) -> ValueRef {
         self.count_insn("udiv");
         unsafe {
@@ -310,6 +338,15 @@ pub fn fdiv(&self, lhs: ValueRef, rhs: ValueRef) -> ValueRef {
         }
     }
 
+    pub fn fdiv_fast(&self, lhs: ValueRef, rhs: ValueRef) -> ValueRef {
+        self.count_insn("fdiv");
+        unsafe {
+            let instr = llvm::LLVMBuildFDiv(self.llbuilder, lhs, rhs, noname());
+            llvm::LLVMRustSetHasUnsafeAlgebra(instr);
+            instr
+        }
+    }
+
     pub fn urem(&self, lhs: ValueRef, rhs: ValueRef) -> ValueRef {
         self.count_insn("urem");
         unsafe {
@@ -331,6 +368,15 @@ pub fn frem(&self, lhs: ValueRef, rhs: ValueRef) -> ValueRef {
         }
     }
 
+    pub fn frem_fast(&self, lhs: ValueRef, rhs: ValueRef) -> ValueRef {
+        self.count_insn("frem");
+        unsafe {
+            let instr = llvm::LLVMBuildFRem(self.llbuilder, lhs, rhs, noname());
+            llvm::LLVMRustSetHasUnsafeAlgebra(instr);
+            instr
+        }
+    }
+
     pub fn shl(&self, lhs: ValueRef, rhs: ValueRef) -> ValueRef {
         self.count_insn("shl");
         unsafe {
index 43976f8233bcefa5979b7c63b363a1ae0a7d6861..0ad65e5dab4fd7ab43e8c11a6fbd96df61c6c449 100644 (file)
@@ -658,6 +658,29 @@ pub fn trans_intrinsic_call<'a, 'blk, 'tcx>(mut bcx: Block<'blk, 'tcx>,
             }
 
         },
+        (_, "fadd_fast") | (_, "fsub_fast") | (_, "fmul_fast") | (_, "fdiv_fast") |
+        (_, "frem_fast") => {
+            let sty = &arg_tys[0].sty;
+            match float_type_width(sty) {
+                Some(_width) =>
+                    match &*name {
+                        "fadd_fast" => FAddFast(bcx, llargs[0], llargs[1], call_debug_location),
+                        "fsub_fast" => FSubFast(bcx, llargs[0], llargs[1], call_debug_location),
+                        "fmul_fast" => FMulFast(bcx, llargs[0], llargs[1], call_debug_location),
+                        "fdiv_fast" => FDivFast(bcx, llargs[0], llargs[1], call_debug_location),
+                        "frem_fast" => FRemFast(bcx, llargs[0], llargs[1], call_debug_location),
+                        _ => unreachable!(),
+                    },
+                None => {
+                    span_invalid_monomorphization_error(
+                        tcx.sess, span,
+                        &format!("invalid monomorphization of `{}` intrinsic: \
+                                  expected basic float type, found `{}`", name, sty));
+                        C_null(llret_ty)
+                }
+            }
+
+        },
 
 
         (_, "return_address") => {
@@ -1700,3 +1723,17 @@ fn int_type_width_signed<'tcx>(sty: &ty::TypeVariants<'tcx>, ccx: &CrateContext)
         _ => None,
     }
 }
+
+// Returns the width of a float TypeVariant
+// Returns None if the type is not a float
+fn float_type_width<'tcx>(sty: &ty::TypeVariants<'tcx>)
+        -> Option<u64> {
+    use rustc::middle::ty::TyFloat;
+    match *sty {
+        TyFloat(t) => Some(match t {
+            ast::FloatTy::F32 => 32,
+            ast::FloatTy::F64 => 64,
+        }),
+        _ => None,
+    }
+}
index a05329bc4a4029b5636e69476453eb24dc5ac11c..3282d17d3a0c2986ef1895c85c7ebab185d30b6f 100644 (file)
@@ -280,6 +280,8 @@ fn param<'a, 'tcx>(ccx: &CrateCtxt<'a, 'tcx>, n: u32) -> Ty<'tcx> {
 
             "overflowing_add" | "overflowing_sub" | "overflowing_mul" =>
                 (1, vec![param(ccx, 0), param(ccx, 0)], param(ccx, 0)),
+            "fadd_fast" | "fsub_fast" | "fmul_fast" | "fdiv_fast" | "frem_fast" =>
+                (1, vec![param(ccx, 0), param(ccx, 0)], param(ccx, 0)),
 
             "return_address" => (0, vec![], tcx.mk_imm_ptr(tcx.types.u8)),
 
index 91cf4aa1da9b0fadc25438c723ec26a246908d8b..f488a517b23fcc292c80ad29cae66445c1c124e2 100644 (file)
@@ -164,6 +164,11 @@ extern "C" void LLVMRemoveFunctionAttrString(LLVMValueRef fn, unsigned index, co
                                           to_remove));
 }
 
+// Enable the fpmath flag UnsafeAlgebra (`fast`) on `Instr`.
+//
+// The IR builder may constant-fold a float operation and hand back a constant
+// rather than an instruction. `unwrap<Instruction>` is a checked cast and must
+// not be applied to such a value, so only real instructions are flagged and
+// anything else is left untouched.
+extern "C" void LLVMRustSetHasUnsafeAlgebra(LLVMValueRef Instr) {
+    if (Instruction *I = dyn_cast<Instruction>(unwrap<Value>(Instr))) {
+        I->setHasUnsafeAlgebra(true);
+    }
+}
+
 extern "C" LLVMValueRef LLVMBuildAtomicLoad(LLVMBuilderRef B,
                                             LLVMValueRef source,
                                             const char* Name,
diff --git a/src/test/codegen/float_math.rs b/src/test/codegen/float_math.rs
new file mode 100644 (file)
index 0000000..bc458d4
--- /dev/null
@@ -0,0 +1,60 @@
+// Copyright 2016 The Rust Project Developers. See the COPYRIGHT
+// file at the top-level directory of this distribution and at
+// http://rust-lang.org/COPYRIGHT.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+// compile-flags: -C no-prepopulate-passes
+
+#![crate_type = "lib"]
+#![feature(core_intrinsics)]
+
+use std::intrinsics::{fadd_fast, fsub_fast, fmul_fast, fdiv_fast, frem_fast};
+
+// CHECK-LABEL: @add
+#[no_mangle]
+pub fn add(x: f32, y: f32) -> f32 {
+// CHECK: fadd float
+// CHECK-NOT fast
+    x + y
+}
+
+// CHECK-LABEL: @addition
+#[no_mangle]
+pub fn addition(x: f32, y: f32) -> f32 {
+// CHECK: fadd fast float
+    unsafe {
+        fadd_fast(x, y)
+    }
+}
+
+// CHECK-LABEL: @subtraction
+#[no_mangle]
+pub fn subtraction(x: f32, y: f32) -> f32 {
+// CHECK: fsub fast float
+    unsafe {
+        fsub_fast(x, y)
+    }
+}
+
+// CHECK-LABEL: @multiplication
+#[no_mangle]
+pub fn multiplication(x: f32, y: f32) -> f32 {
+// CHECK: fmul fast float
+    unsafe {
+        fmul_fast(x, y)
+    }
+}
+
+// CHECK-LABEL: @division
+#[no_mangle]
+pub fn division(x: f32, y: f32) -> f32 {
+// CHECK: fdiv fast float
+    unsafe {
+        fdiv_fast(x, y)
+    }
+}
diff --git a/src/test/run-pass/float_math.rs b/src/test/run-pass/float_math.rs
new file mode 100644 (file)
index 0000000..396b732
--- /dev/null
@@ -0,0 +1,24 @@
+// Copyright 2016 The Rust Project Developers. See the COPYRIGHT
+// file at the top-level directory of this distribution and at
+// http://rust-lang.org/COPYRIGHT.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![feature(core_intrinsics)]
+
+use std::intrinsics::{fadd_fast, fsub_fast, fmul_fast, fdiv_fast, frem_fast};
+
+fn main() {
+    // Compare each intrinsic with the built-in operator to check it maps to the right operation.
+    unsafe {
+        assert_eq!(fadd_fast(1., 2.), 1. + 2.);
+        assert_eq!(fsub_fast(1., 2.), 1. - 2.);
+        assert_eq!(fmul_fast(2., 3.), 2. * 3.);
+        assert_eq!(fdiv_fast(10., 5.), 10. / 5.);
+        assert_eq!(frem_fast(10., 5.), 10. % 5.);
+    }
+}