]> git.lizzy.rs Git - rust.git/blob - library/portable-simd/crates/core_simd/examples/spectral_norm.rs
Rollup merge of #106854 - steffahn:drop_linear_arc_rebased, r=Mark-Simulacrum
[rust.git] / library / portable-simd / crates / core_simd / examples / spectral_norm.rs
1 #![feature(portable_simd)]
2
3 use core_simd::simd::*;
4
5 fn a(i: usize, j: usize) -> f64 {
6     ((i + j) * (i + j + 1) / 2 + i + 1) as f64
7 }
8
9 fn mult_av(v: &[f64], out: &mut [f64]) {
10     assert!(v.len() == out.len());
11     assert!(v.len() % 2 == 0);
12
13     for (i, out) in out.iter_mut().enumerate() {
14         let mut sum = f64x2::splat(0.0);
15
16         let mut j = 0;
17         while j < v.len() {
18             let b = f64x2::from_slice(&v[j..]);
19             let a = f64x2::from_array([a(i, j), a(i, j + 1)]);
20             sum += b / a;
21             j += 2
22         }
23         *out = sum.reduce_sum();
24     }
25 }
26
27 fn mult_atv(v: &[f64], out: &mut [f64]) {
28     assert!(v.len() == out.len());
29     assert!(v.len() % 2 == 0);
30
31     for (i, out) in out.iter_mut().enumerate() {
32         let mut sum = f64x2::splat(0.0);
33
34         let mut j = 0;
35         while j < v.len() {
36             let b = f64x2::from_slice(&v[j..]);
37             let a = f64x2::from_array([a(j, i), a(j + 1, i)]);
38             sum += b / a;
39             j += 2
40         }
41         *out = sum.reduce_sum();
42     }
43 }
44
45 fn mult_atav(v: &[f64], out: &mut [f64], tmp: &mut [f64]) {
46     mult_av(v, tmp);
47     mult_atv(tmp, out);
48 }
49
50 pub fn spectral_norm(n: usize) -> f64 {
51     assert!(n % 2 == 0, "only even lengths are accepted");
52
53     let mut u = vec![1.0; n];
54     let mut v = u.clone();
55     let mut tmp = u.clone();
56
57     for _ in 0..10 {
58         mult_atav(&u, &mut v, &mut tmp);
59         mult_atav(&v, &mut u, &mut tmp);
60     }
61     (dot(&u, &v) / dot(&v, &v)).sqrt()
62 }
63
64 fn dot(x: &[f64], y: &[f64]) -> f64 {
65     // This is auto-vectorized:
66     x.iter().zip(y).map(|(&x, &y)| x * y).sum()
67 }
68
69 #[cfg(test)]
70 #[test]
71 fn test() {
72     assert_eq!(format!("{:.9}", spectral_norm(100)), "1.274219991");
73 }
74
75 fn main() {
76     // Empty main to make cargo happy
77 }