]> git.lizzy.rs Git - rust.git/commitdiff
bench: rewrite nbody for better vectorization
authorCristi Cobzarenco <cristi.cobzarenco@gmail.com>
Wed, 7 Oct 2015 19:45:21 +0000 (20:45 +0100)
committerCristi Cobzarenco <cristi.cobzarenco@gmail.com>
Wed, 7 Oct 2015 20:05:22 +0000 (21:05 +0100)
src/test/bench/shootout-nbody.rs

index 5ba678ce183dd14cd8a3463dbcfb02321b33bba5..64d6a8109888eb50c5cf33882310644609beaab1 100644 (file)
 // OF THE POSSIBILITY OF SUCH DAMAGE.
 
 use std::mem;
+use std::ops::{Add, Sub, Mul};
 
 const PI: f64 = 3.141592653589793;
 const SOLAR_MASS: f64 = 4.0 * PI * PI;
 const YEAR: f64 = 365.24;
 const N_BODIES: usize = 5;
+const N_PAIRS: usize = N_BODIES * (N_BODIES - 1) / 2;
 
-static BODIES: [Planet;N_BODIES] = [
+const BODIES: [Planet; N_BODIES] = [
     // Sun
     Planet {
-        x: 0.0, y: 0.0, z: 0.0,
-        vx: 0.0, vy: 0.0, vz: 0.0,
+        pos: Vec3(0.0, 0.0, 0.0),
+        vel: Vec3(0.0, 0.0, 0.0),
         mass: SOLAR_MASS,
     },
     // Jupiter
     Planet {
-        x: 4.84143144246472090e+00,
-        y: -1.16032004402742839e+00,
-        z: -1.03622044471123109e-01,
-        vx: 1.66007664274403694e-03 * YEAR,
-        vy: 7.69901118419740425e-03 * YEAR,
-        vz: -6.90460016972063023e-05 * YEAR,
+        pos: Vec3(4.84143144246472090e+00,
+                  -1.16032004402742839e+00,
+                  -1.03622044471123109e-01),
+        vel: Vec3(1.66007664274403694e-03 * YEAR,
+                  7.69901118419740425e-03 * YEAR,
+                  -6.90460016972063023e-05 * YEAR),
         mass: 9.54791938424326609e-04 * SOLAR_MASS,
     },
     // Saturn
     Planet {
-        x: 8.34336671824457987e+00,
-        y: 4.12479856412430479e+00,
-        z: -4.03523417114321381e-01,
-        vx: -2.76742510726862411e-03 * YEAR,
-        vy: 4.99852801234917238e-03 * YEAR,
-        vz: 2.30417297573763929e-05 * YEAR,
+        pos: Vec3(8.34336671824457987e+00,
+                  4.12479856412430479e+00,
+                  -4.03523417114321381e-01),
+        vel: Vec3(-2.76742510726862411e-03 * YEAR,
+                  4.99852801234917238e-03 * YEAR,
+                  2.30417297573763929e-05 * YEAR),
         mass: 2.85885980666130812e-04 * SOLAR_MASS,
     },
     // Uranus
     Planet {
-        x: 1.28943695621391310e+01,
-        y: -1.51111514016986312e+01,
-        z: -2.23307578892655734e-01,
-        vx: 2.96460137564761618e-03 * YEAR,
-        vy: 2.37847173959480950e-03 * YEAR,
-        vz: -2.96589568540237556e-05 * YEAR,
+        pos: Vec3(1.28943695621391310e+01,
+                  -1.51111514016986312e+01,
+                  -2.23307578892655734e-01),
+        vel: Vec3(2.96460137564761618e-03 * YEAR,
+                  2.37847173959480950e-03 * YEAR,
+                  -2.96589568540237556e-05 * YEAR),
         mass: 4.36624404335156298e-05 * SOLAR_MASS,
     },
     // Neptune
     Planet {
-        x: 1.53796971148509165e+01,
-        y: -2.59193146099879641e+01,
-        z: 1.79258772950371181e-01,
-        vx: 2.68067772490389322e-03 * YEAR,
-        vy: 1.62824170038242295e-03 * YEAR,
-        vz: -9.51592254519715870e-05 * YEAR,
+        pos: Vec3(1.53796971148509165e+01,
+                  -2.59193146099879641e+01,
+                  1.79258772950371181e-01),
+        vel: Vec3(2.68067772490389322e-03 * YEAR,
+                  1.62824170038242295e-03 * YEAR,
+                  -9.51592254519715870e-05 * YEAR),
         mass: 5.15138902046611451e-05 * SOLAR_MASS,
     },
 ];
 
-#[derive(Copy, Clone)]
+/// A 3d Vector type with oveloaded operators to improve readability.
+#[derive(Clone, Copy)]
+struct Vec3(pub f64, pub f64, pub f64);
+
+impl Vec3 {
+    fn zero() -> Self { Vec3(0.0, 0.0, 0.0) }
+
+    fn norm(&self) -> f64 { self.squared_norm().sqrt() }
+
+    fn squared_norm(&self) -> f64 {
+        self.0 * self.0 + self.1 * self.1 + self.2 * self.2
+    }
+}
+
+impl Add for Vec3 {
+    type Output = Self;
+    fn add(self, rhs: Self) -> Self {
+        Vec3(self.0 + rhs.0, self.1 + rhs.1, self.2 + rhs.2)
+    }
+}
+
+impl Sub for Vec3 {
+    type Output = Self;
+    fn sub(self, rhs: Self) -> Self {
+        Vec3(self.0 - rhs.0, self.1 - rhs.1, self.2 - rhs.2)
+    }
+}
+
+impl Mul<f64> for Vec3 {
+    type Output = Self;
+    fn mul(self, rhs: f64) -> Self {
+        Vec3(self.0 * rhs, self.1 * rhs, self.2 * rhs)
+    }
+}
+
+#[derive(Clone, Copy)]
 struct Planet {
-    x: f64, y: f64, z: f64,
-    vx: f64, vy: f64, vz: f64,
+    pos: Vec3,
+    vel: Vec3,
     mass: f64,
 }
 
-fn advance(bodies: &mut [Planet;N_BODIES], dt: f64, steps: isize) {
-    for _ in 0..steps {
-        let mut b_slice: &mut [_] = bodies;
-        loop {
-            let bi = match shift_mut_ref(&mut b_slice) {
-                Some(bi) => bi,
-                None => break
-            };
-            for bj in &mut *b_slice {
-                let dx = bi.x - bj.x;
-                let dy = bi.y - bj.y;
-                let dz = bi.z - bj.z;
-
-                let d2 = dx * dx + dy * dy + dz * dz;
-                let mag = dt / (d2 * d2.sqrt());
-
-                let massj_mag = bj.mass * mag;
-                bi.vx -= dx * massj_mag;
-                bi.vy -= dy * massj_mag;
-                bi.vz -= dz * massj_mag;
-
-                let massi_mag = bi.mass * mag;
-                bj.vx += dx * massi_mag;
-                bj.vy += dy * massi_mag;
-                bj.vz += dz * massi_mag;
-            }
-            bi.x += dt * bi.vx;
-            bi.y += dt * bi.vy;
-            bi.z += dt * bi.vz;
+/// Computes all pairwise position differences between the planets.
+fn pairwise_diffs(bodies: &[Planet; N_BODIES], diff: &mut [Vec3; N_PAIRS]) {
+    let mut bodies = bodies.iter();
+    let mut diff = diff.iter_mut();
+    while let Some(bi) = bodies.next() {
+        for bj in bodies.clone() {
+            *diff.next().unwrap() = bi.pos - bj.pos;
+        }
+    }
+}
+
+/// Computes the magnitude of the force between each pair of planets.
+fn magnitudes(diff: &[Vec3; N_PAIRS], dt: f64, mag: &mut [f64; N_PAIRS]) {
+    for (mag, diff) in mag.iter_mut().zip(diff.iter()) {
+        let d2 = diff.squared_norm();
+        *mag = dt / (d2 * d2.sqrt());
+    }
+}
+
+/// Updates the velocities of the planets by computing their gravitational
+/// accelerations and performing one step of Euler integration.
+fn update_velocities(bodies: &mut [Planet; N_BODIES], dt: f64,
+                     diff: &mut [Vec3; N_PAIRS], mag: &mut [f64; N_PAIRS]) {
+    pairwise_diffs(bodies, diff);
+    magnitudes(&diff, dt, mag);
+
+    let mut bodies = &mut bodies[..];
+    let mut mag = mag.iter();
+    let mut diff = diff.iter();
+    while let Some(bi) = shift_mut_ref(&mut bodies) {
+        for bj in bodies.iter_mut() {
+            let diff = *diff.next().unwrap();
+            let mag = *mag.next().unwrap();
+            bi.vel = bi.vel - diff * (bj.mass * mag);
+            bj.vel = bj.vel + diff * (bi.mass * mag);
         }
     }
 }
 
-fn energy(bodies: &[Planet;N_BODIES]) -> f64 {
+/// Advances the solar system by one timestep by first updating the
+/// velocities and then integrating the positions using the updated velocities.
+///
+/// Note: the `diff` & `mag` arrays are effectively scratch space. They're
+/// provided as arguments to avoid re-zeroing them every time `advance` is
+/// called.
+fn advance(mut bodies: &mut [Planet; N_BODIES], dt: f64,
+           diff: &mut [Vec3; N_PAIRS], mag: &mut [f64; N_PAIRS]) {
+    update_velocities(bodies, dt, diff, mag);
+    for body in bodies.iter_mut() {
+        body.pos = body.pos + body.vel * dt;
+    }
+}
+
+/// Computes the total energy of the solar system.
+fn energy(bodies: &[Planet; N_BODIES]) -> f64 {
     let mut e = 0.0;
     let mut bodies = bodies.iter();
-    loop {
-        let bi = match bodies.next() {
-            Some(bi) => bi,
-            None => break
-        };
-        e += (bi.vx * bi.vx + bi.vy * bi.vy + bi.vz * bi.vz) * bi.mass / 2.0;
-        for bj in bodies.clone() {
-            let dx = bi.x - bj.x;
-            let dy = bi.y - bj.y;
-            let dz = bi.z - bj.z;
-            let dist = (dx * dx + dy * dy + dz * dz).sqrt();
-            e -= bi.mass * bj.mass / dist;
-        }
+    while let Some(bi) = bodies.next() {
+        e += bi.vel.squared_norm() * bi.mass / 2.0
+           - bi.mass * bodies.clone()
+                             .map(|bj| bj.mass / (bi.pos - bj.pos).norm())
+                             .fold(0.0, |a, b| a + b);
     }
     e
 }
 
-fn offset_momentum(bodies: &mut [Planet;N_BODIES]) {
-    let mut px = 0.0;
-    let mut py = 0.0;
-    let mut pz = 0.0;
-    for bi in bodies.iter() {
-        px += bi.vx * bi.mass;
-        py += bi.vy * bi.mass;
-        pz += bi.vz * bi.mass;
-    }
-    let sun = &mut bodies[0];
-    sun.vx = - px / SOLAR_MASS;
-    sun.vy = - py / SOLAR_MASS;
-    sun.vz = - pz / SOLAR_MASS;
+/// Offsets the sun's velocity to make the overall momentum of the system zero.
+fn offset_momentum(bodies: &mut [Planet; N_BODIES]) {
+    let p = bodies.iter().fold(Vec3::zero(), |v, b| v + b.vel * b.mass);
+    bodies[0].vel = p * (-1.0 / bodies[0].mass);
 }
 
 fn main() {
@@ -178,11 +219,15 @@ fn main() {
             .unwrap_or(1000)
     };
     let mut bodies = BODIES;
+    let mut diff = [Vec3::zero(); N_PAIRS];
+    let mut mag = [0.0f64; N_PAIRS];
 
     offset_momentum(&mut bodies);
     println!("{:.9}", energy(&bodies));
 
-    advance(&mut bodies, 0.01, n);
+    for _ in (0..n) {
+        advance(&mut bodies, 0.01, &mut diff, &mut mag);
+    }
 
     println!("{:.9}", energy(&bodies));
 }