#include #include #include #include template struct equal_to_value_pred { T value; equal_to_value_pred(T value) : value(value) {} __host__ __device__ bool operator()(T v) const { return v == value; } }; template struct not_equal_to_value_pred { T value; not_equal_to_value_pred(T value) : value(value) {} __host__ __device__ bool operator()(T v) const { return v != value; } }; template struct less_than_value_pred { T value; less_than_value_pred(T value) : value(value) {} __host__ __device__ bool operator()(T v) const { return v < value; } }; template void TestFindSimple(void) { Vector vec(5); vec[0] = 1; vec[1] = 2; vec[2] = 3; vec[3] = 3; vec[4] = 5; ASSERT_EQUAL(thrust::find(vec.begin(), vec.end(), 0) - vec.begin(), 5); ASSERT_EQUAL(thrust::find(vec.begin(), vec.end(), 1) - vec.begin(), 0); ASSERT_EQUAL(thrust::find(vec.begin(), vec.end(), 2) - vec.begin(), 1); ASSERT_EQUAL(thrust::find(vec.begin(), vec.end(), 3) - vec.begin(), 2); ASSERT_EQUAL(thrust::find(vec.begin(), vec.end(), 4) - vec.begin(), 5); ASSERT_EQUAL(thrust::find(vec.begin(), vec.end(), 5) - vec.begin(), 4); } DECLARE_VECTOR_UNITTEST(TestFindSimple); template InputIterator find(my_system &system, InputIterator first, InputIterator, const T&) { system.validate_dispatch(); return first; } void TestFindDispatchExplicit() { thrust::device_vector vec(1); my_system sys(0); thrust::find(sys, vec.begin(), vec.end(), 0); ASSERT_EQUAL(true, sys.is_valid()); } DECLARE_UNITTEST(TestFindDispatchExplicit); template InputIterator find(my_tag, InputIterator first, InputIterator, const T&) { *first = 13; return first; } void TestFindDispatchImplicit() { thrust::device_vector vec(1); thrust::find(thrust::retag(vec.begin()), thrust::retag(vec.end()), 0); ASSERT_EQUAL(13, vec.front()); } DECLARE_UNITTEST(TestFindDispatchImplicit); template void TestFindIfSimple(void) { typedef typename Vector::value_type T; Vector vec(5); vec[0] = 1; vec[1] = 2; vec[2] = 3; vec[3] = 3; vec[4] = 5; ASSERT_EQUAL(thrust::find_if(vec.begin(), vec.end(), equal_to_value_pred(0)) - vec.begin(), 5); ASSERT_EQUAL(thrust::find_if(vec.begin(), vec.end(), equal_to_value_pred(1)) - vec.begin(), 0); ASSERT_EQUAL(thrust::find_if(vec.begin(), vec.end(), equal_to_value_pred(2)) - vec.begin(), 1); ASSERT_EQUAL(thrust::find_if(vec.begin(), vec.end(), equal_to_value_pred(3)) - vec.begin(), 2); ASSERT_EQUAL(thrust::find_if(vec.begin(), vec.end(), equal_to_value_pred(4)) - vec.begin(), 5); ASSERT_EQUAL(thrust::find_if(vec.begin(), vec.end(), equal_to_value_pred(5)) - vec.begin(), 4); } DECLARE_VECTOR_UNITTEST(TestFindIfSimple); template InputIterator find_if(my_system &system, InputIterator first, InputIterator, Predicate) { system.validate_dispatch(); return first; } void TestFindIfDispatchExplicit() { thrust::device_vector vec(1); my_system sys(0); thrust::find_if(sys, vec.begin(), vec.end(), thrust::identity()); ASSERT_EQUAL(true, sys.is_valid()); } DECLARE_UNITTEST(TestFindIfDispatchExplicit); template InputIterator find_if(my_tag, InputIterator first, InputIterator, Predicate) { *first = 13; return first; } void TestFindIfDispatchImplicit() { thrust::device_vector vec(1); thrust::find_if(thrust::retag(vec.begin()), thrust::retag(vec.end()), thrust::identity()); ASSERT_EQUAL(13, vec.front()); } DECLARE_UNITTEST(TestFindIfDispatchImplicit); template void TestFindIfNotSimple(void) { typedef typename Vector::value_type T; Vector vec(5); vec[0] = 0; vec[1] = 1; vec[2] = 2; vec[3] = 3; vec[4] = 4; ASSERT_EQUAL(0, thrust::find_if_not(vec.begin(), vec.end(), less_than_value_pred(0)) - vec.begin()); ASSERT_EQUAL(1, thrust::find_if_not(vec.begin(), vec.end(), less_than_value_pred(1)) - vec.begin()); ASSERT_EQUAL(2, thrust::find_if_not(vec.begin(), vec.end(), less_than_value_pred(2)) - vec.begin()); ASSERT_EQUAL(3, thrust::find_if_not(vec.begin(), vec.end(), less_than_value_pred(3)) - vec.begin()); ASSERT_EQUAL(4, thrust::find_if_not(vec.begin(), vec.end(), less_than_value_pred(4)) - vec.begin()); ASSERT_EQUAL(5, thrust::find_if_not(vec.begin(), vec.end(), less_than_value_pred(5)) - vec.begin()); } DECLARE_VECTOR_UNITTEST(TestFindIfNotSimple); template InputIterator find_if_not(my_system &system, InputIterator first, InputIterator, Predicate) { system.validate_dispatch(); return first; } void TestFindIfNotDispatchExplicit() { thrust::device_vector vec(1); my_system sys(0); thrust::find_if_not(sys, vec.begin(), vec.end(), thrust::identity()); ASSERT_EQUAL(true, sys.is_valid()); } DECLARE_UNITTEST(TestFindIfNotDispatchExplicit); template InputIterator find_if_not(my_tag, InputIterator first, InputIterator, Predicate) { *first = 13; return first; } void TestFindIfNotDispatchImplicit() { thrust::device_vector vec(1); thrust::find_if_not(thrust::retag(vec.begin()), thrust::retag(vec.end()), thrust::identity()); ASSERT_EQUAL(13, vec.front()); } DECLARE_UNITTEST(TestFindIfNotDispatchImplicit); template struct TestFind { void operator()(const size_t n) { thrust::host_vector h_data = unittest::random_integers(n); thrust::device_vector d_data = h_data; typename thrust::host_vector::iterator h_iter; typename thrust::device_vector::iterator d_iter; h_iter = thrust::find(h_data.begin(), h_data.end(), T(0)); d_iter = thrust::find(d_data.begin(), d_data.end(), T(0)); ASSERT_EQUAL(h_iter - h_data.begin(), d_iter - d_data.begin()); for (size_t i = 1; i < n; i *= 2) { T sample = h_data[i]; h_iter = thrust::find(h_data.begin(), h_data.end(), sample); d_iter = thrust::find(d_data.begin(), d_data.end(), sample); ASSERT_EQUAL(h_iter - h_data.begin(), d_iter - d_data.begin()); } } }; VariableUnitTest TestFindInstance; template struct TestFindIf { void operator()(const size_t n) { thrust::host_vector h_data = unittest::random_integers(n); thrust::device_vector d_data = h_data; typename thrust::host_vector::iterator h_iter; typename thrust::device_vector::iterator d_iter; h_iter = thrust::find_if(h_data.begin(), h_data.end(), equal_to_value_pred(0)); d_iter = thrust::find_if(d_data.begin(), d_data.end(), equal_to_value_pred(0)); ASSERT_EQUAL(h_iter - h_data.begin(), d_iter - d_data.begin()); for (size_t i = 1; i < n; i *= 2) { T sample = h_data[i]; h_iter = thrust::find_if(h_data.begin(), h_data.end(), equal_to_value_pred(sample)); d_iter = thrust::find_if(d_data.begin(), d_data.end(), equal_to_value_pred(sample)); ASSERT_EQUAL(h_iter - h_data.begin(), d_iter - d_data.begin()); } } }; VariableUnitTest TestFindIfInstance; template struct TestFindIfNot { void operator()(const size_t n) { thrust::host_vector h_data = unittest::random_integers(n); thrust::device_vector d_data = h_data; typename thrust::host_vector::iterator h_iter; typename thrust::device_vector::iterator d_iter; h_iter = thrust::find_if_not(h_data.begin(), h_data.end(), not_equal_to_value_pred(0)); d_iter = thrust::find_if_not(d_data.begin(), d_data.end(), not_equal_to_value_pred(0)); ASSERT_EQUAL(h_iter - h_data.begin(), d_iter - d_data.begin()); for (size_t i = 1; i < n; i *= 2) { T sample = h_data[i]; h_iter = thrust::find_if_not(h_data.begin(), h_data.end(), not_equal_to_value_pred(sample)); d_iter = thrust::find_if_not(d_data.begin(), d_data.end(), not_equal_to_value_pred(sample)); ASSERT_EQUAL(h_iter - h_data.begin(), d_iter - d_data.begin()); } } }; VariableUnitTest TestFindIfNotInstance; void TestFindWithBigIndexesHelper(int magnitude) { thrust::counting_iterator begin(1); thrust::counting_iterator end = begin + (1ll << magnitude); ASSERT_EQUAL(thrust::distance(begin, end), 1ll << magnitude); thrust::detail::intmax_t distance_low_value = thrust::distance( begin, thrust::find( thrust::device, begin, end, 17)); thrust::detail::intmax_t distance_high_value = thrust::distance( begin, thrust::find( thrust::device, begin, end, (1ll << magnitude) - 17)); ASSERT_EQUAL(distance_low_value, 16); ASSERT_EQUAL(distance_high_value, (1ll << magnitude) - 18); } void TestFindWithBigIndexes() { TestFindWithBigIndexesHelper(30); TestFindWithBigIndexesHelper(31); TestFindWithBigIndexesHelper(32); TestFindWithBigIndexesHelper(33); } DECLARE_UNITTEST(TestFindWithBigIndexes); namespace { class Weird { int value; public: __host__ __device__ Weird(int val, int) : value(val) {} friend __host__ __device__ bool operator==(int x, Weird y) { return x == y.value; } }; } // end anon namespace void TestFindAsymmetricEquality() { // Regression test for thrust/thrust#1229 thrust::host_vector v(1000); thrust::sequence(v.begin(), v.end()); thrust::device_vector dv(v); auto result = thrust::find(dv.begin(), dv.end(), Weird(333, 0)); ASSERT_EQUAL(*result, 333); ASSERT_EQUAL(result - dv.begin(), 333); } DECLARE_UNITTEST(TestFindAsymmetricEquality);