From 6d223373136933c2a2472bdf4e2c4ed223d682f8 Mon Sep 17 00:00:00 2001 From: Alissa Rao Date: Wed, 24 Feb 2021 09:24:46 -0800 Subject: [PATCH] Implement core::iter::Sum for EnumSet --- enumset/src/lib.rs | 22 +++++++++++++++++++++- enumset/tests/ops.rs | 17 +++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/enumset/src/lib.rs b/enumset/src/lib.rs index 55aab1e..66af9e3 100644 --- a/enumset/src/lib.rs +++ b/enumset/src/lib.rs @@ -81,7 +81,7 @@ use core::cmp::Ordering; use core::fmt; use core::fmt::{Debug, Formatter}; use core::hash::{Hash, Hasher}; -use core::iter::FromIterator; +use core::iter::{FromIterator, Sum}; use core::ops::*; #[doc(hidden)] @@ -462,6 +462,26 @@ impl IntoIterator for EnumSet { self.iter() } } +impl Sum for EnumSet { + fn sum>(iter: I) -> Self { + iter.fold(EnumSet::empty(), |a, v| a | v) + } +} +impl <'a, T: EnumSetType> Sum<&'a EnumSet> for EnumSet { + fn sum>(iter: I) -> Self { + iter.fold(EnumSet::empty(), |a, v| a | *v) + } +} +impl Sum for EnumSet { + fn sum>(iter: I) -> Self { + iter.fold(EnumSet::empty(), |a, v| a | v) + } +} +impl <'a, T: EnumSetType> Sum<&'a T> for EnumSet { + fn sum>(iter: I) -> Self { + iter.fold(EnumSet::empty(), |a, v| a | *v) + } +} impl >> Sub for EnumSet { type Output = Self; diff --git a/enumset/tests/ops.rs b/enumset/tests/ops.rs index aa7019b..44694d9 100644 --- a/enumset/tests/ops.rs +++ b/enumset/tests/ops.rs @@ -323,6 +323,23 @@ macro_rules! test_enum { test_set!(tree_set); } + #[test] + fn sum_test() { + let target = $e::A | $e::B | $e::D | $e::E | $e::G | $e::H; + + let list_a = [$e::A | $e::B, $e::D | $e::E, $e::G | $e::H]; + let sum_a: EnumSet<$e> = list_a.iter().map(|x| *x).sum(); + assert_eq!(target, sum_a); + let sum_b: EnumSet<$e> = list_a.iter().sum(); + assert_eq!(target, sum_b); + + let list_b = [$e::A, $e::B, $e::D, $e::E, $e::G, $e::H]; + let sum_c: EnumSet<$e> = list_b.iter().map(|x| *x).sum(); + assert_eq!(target, sum_c); + let sum_d: EnumSet<$e> = list_b.iter().sum(); + assert_eq!(target, sum_d); + } + #[test] fn check_size() { assert_eq!(::std::mem::size_of::>(), $mem_size); -- 2.44.0