threadpool/
lib.rs

1// Copyright 2014 The Rust Project Developers. See the COPYRIGHT
2// file at the top-level directory of this distribution and at
3// http://rust-lang.org/COPYRIGHT.
4//
5// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8// option. This file may not be copied, modified, or distributed
9// except according to those terms.
10
11//! A thread pool used to execute functions in parallel.
12//!
13//! Spawns a specified number of worker threads and replenishes the pool if any worker threads
14//! panic.
15//!
16//! # Examples
17//!
18//! ## Synchronized with a channel
19//!
//! Every job sends one message over the channel; the messages are then collected with `take()`.
21//!
22//! ```
23//! use threadpool::ThreadPool;
24//! use std::sync::mpsc::channel;
25//!
26//! let n_workers = 4;
27//! let n_jobs = 8;
28//! let pool = ThreadPool::new(n_workers);
29//!
30//! let (tx, rx) = channel();
31//! for _ in 0..n_jobs {
32//!     let tx = tx.clone();
33//!     pool.execute(move|| {
34//!         tx.send(1).expect("channel will be there waiting for the pool");
35//!     });
36//! }
37//!
38//! assert_eq!(rx.iter().take(n_jobs).fold(0, |a, b| a + b), 8);
39//! ```
40//!
41//! ## Synchronized with a barrier
42//!
43//! Keep in mind, if a barrier synchronizes more jobs than you have workers in the pool,
44//! you will end up with a [deadlock](https://en.wikipedia.org/wiki/Deadlock)
45//! at the barrier which is [not considered unsafe](
46//! https://doc.rust-lang.org/reference/behavior-not-considered-unsafe.html).
47//!
48//! ```
49//! use threadpool::ThreadPool;
50//! use std::sync::{Arc, Barrier};
51//! use std::sync::atomic::{AtomicUsize, Ordering};
52//!
53//! // create at least as many workers as jobs or you will deadlock yourself
54//! let n_workers = 42;
55//! let n_jobs = 23;
56//! let pool = ThreadPool::new(n_workers);
57//! let an_atomic = Arc::new(AtomicUsize::new(0));
58//!
59//! assert!(n_jobs <= n_workers, "too many jobs, will deadlock");
60//!
61//! // create a barrier that waits for all jobs plus the starter thread
62//! let barrier = Arc::new(Barrier::new(n_jobs + 1));
63//! for _ in 0..n_jobs {
64//!     let barrier = barrier.clone();
65//!     let an_atomic = an_atomic.clone();
66//!
67//!     pool.execute(move|| {
68//!         // do the heavy work
69//!         an_atomic.fetch_add(1, Ordering::Relaxed);
70//!
71//!         // then wait for the other threads
72//!         barrier.wait();
73//!     });
74//! }
75//!
76//! // wait for the threads to finish the work
77//! barrier.wait();
78//! assert_eq!(an_atomic.load(Ordering::SeqCst), /* n_jobs = */ 23);
79//! ```
80
81extern crate num_cpus;
82
83use std::fmt;
84use std::sync::atomic::{AtomicUsize, Ordering};
85use std::sync::mpsc::{channel, Receiver, Sender};
86use std::sync::{Arc, Condvar, Mutex};
87use std::thread;
88
89trait FnBox {
90    fn call_box(self: Box<Self>);
91}
92
93impl<F: FnOnce()> FnBox for F {
94    fn call_box(self: Box<F>) {
95        (*self)()
96    }
97}
98
99type Thunk<'a> = Box<FnBox + Send + 'a>;
100
101struct Sentinel<'a> {
102    shared_data: &'a Arc<ThreadPoolSharedData>,
103    active: bool,
104}
105
106impl<'a> Sentinel<'a> {
107    fn new(shared_data: &'a Arc<ThreadPoolSharedData>) -> Sentinel<'a> {
108        Sentinel {
109            shared_data: shared_data,
110            active: true,
111        }
112    }
113
114    /// Cancel and destroy this sentinel.
115    fn cancel(mut self) {
116        self.active = false;
117    }
118}
119
120impl<'a> Drop for Sentinel<'a> {
121    fn drop(&mut self) {
122        if self.active {
123            self.shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
124            if thread::panicking() {
125                self.shared_data.panic_count.fetch_add(1, Ordering::SeqCst);
126            }
127            self.shared_data.no_work_notify_all();
128            spawn_in_pool(self.shared_data.clone())
129        }
130    }
131}
132
133/// [`ThreadPool`] factory, which can be used in order to configure the properties of the
134/// [`ThreadPool`].
135///
136/// The three configuration options available:
137///
138/// * `num_threads`: maximum number of threads that will be alive at any given moment by the built
139///   [`ThreadPool`]
140/// * `thread_name`: thread name for each of the threads spawned by the built [`ThreadPool`]
141/// * `thread_stack_size`: stack size (in bytes) for each of the threads spawned by the built
142///   [`ThreadPool`]
143///
144/// [`ThreadPool`]: struct.ThreadPool.html
145///
146/// # Examples
147///
/// Build a [`ThreadPool`] that uses a maximum of eight threads simultaneously, each with
/// an 8 MB stack:
150///
151/// ```
152/// let pool = threadpool::Builder::new()
153///     .num_threads(8)
154///     .thread_stack_size(8_000_000)
155///     .build();
156/// ```
#[derive(Clone, Default)]
pub struct Builder {
    // Requested worker count; when `None`, `build` falls back to `num_cpus::get`.
    num_threads: Option<usize>,
    // Name handed to every spawned worker thread; `None` leaves threads unnamed.
    thread_name: Option<String>,
    // Stack size in bytes handed to every spawned worker thread, if set.
    thread_stack_size: Option<usize>,
}
163
164impl Builder {
165    /// Initiate a new [`Builder`].
166    ///
167    /// [`Builder`]: struct.Builder.html
168    ///
169    /// # Examples
170    ///
171    /// ```
172    /// let builder = threadpool::Builder::new();
173    /// ```
174    pub fn new() -> Builder {
175        Builder {
176            num_threads: None,
177            thread_name: None,
178            thread_stack_size: None,
179        }
180    }
181
182    /// Set the maximum number of worker-threads that will be alive at any given moment by the built
183    /// [`ThreadPool`]. If not specified, defaults the number of threads to the number of CPUs.
184    ///
185    /// [`ThreadPool`]: struct.ThreadPool.html
186    ///
187    /// # Panics
188    ///
189    /// This method will panic if `num_threads` is 0.
190    ///
191    /// # Examples
192    ///
193    /// No more than eight threads will be alive simultaneously for this pool:
194    ///
195    /// ```
196    /// use std::thread;
197    ///
198    /// let pool = threadpool::Builder::new()
199    ///     .num_threads(8)
200    ///     .build();
201    ///
202    /// for _ in 0..100 {
203    ///     pool.execute(|| {
204    ///         println!("Hello from a worker thread!")
205    ///     })
206    /// }
207    /// ```
208    pub fn num_threads(mut self, num_threads: usize) -> Builder {
209        assert!(num_threads > 0);
210        self.num_threads = Some(num_threads);
211        self
212    }
213
214    /// Set the thread name for each of the threads spawned by the built [`ThreadPool`]. If not
215    /// specified, threads spawned by the thread pool will be unnamed.
216    ///
217    /// [`ThreadPool`]: struct.ThreadPool.html
218    ///
219    /// # Examples
220    ///
221    /// Each thread spawned by this pool will have the name "foo":
222    ///
223    /// ```
224    /// use std::thread;
225    ///
226    /// let pool = threadpool::Builder::new()
227    ///     .thread_name("foo".into())
228    ///     .build();
229    ///
230    /// for _ in 0..100 {
231    ///     pool.execute(|| {
232    ///         assert_eq!(thread::current().name(), Some("foo"));
233    ///     })
234    /// }
235    /// ```
236    pub fn thread_name(mut self, name: String) -> Builder {
237        self.thread_name = Some(name);
238        self
239    }
240
241    /// Set the stack size (in bytes) for each of the threads spawned by the built [`ThreadPool`].
242    /// If not specified, threads spawned by the threadpool will have a stack size [as specified in
243    /// the `std::thread` documentation][thread].
244    ///
245    /// [thread]: https://doc.rust-lang.org/nightly/std/thread/index.html#stack-size
246    /// [`ThreadPool`]: struct.ThreadPool.html
247    ///
248    /// # Examples
249    ///
250    /// Each thread spawned by this pool will have a 4 MB stack:
251    ///
252    /// ```
253    /// let pool = threadpool::Builder::new()
254    ///     .thread_stack_size(4_000_000)
255    ///     .build();
256    ///
257    /// for _ in 0..100 {
258    ///     pool.execute(|| {
259    ///         println!("This thread has a 4 MB stack size!");
260    ///     })
261    /// }
262    /// ```
263    pub fn thread_stack_size(mut self, size: usize) -> Builder {
264        self.thread_stack_size = Some(size);
265        self
266    }
267
268    /// Finalize the [`Builder`] and build the [`ThreadPool`].
269    ///
270    /// [`Builder`]: struct.Builder.html
271    /// [`ThreadPool`]: struct.ThreadPool.html
272    ///
273    /// # Examples
274    ///
275    /// ```
276    /// let pool = threadpool::Builder::new()
277    ///     .num_threads(8)
278    ///     .thread_stack_size(4_000_000)
279    ///     .build();
280    /// ```
281    pub fn build(self) -> ThreadPool {
282        let (tx, rx) = channel::<Thunk<'static>>();
283
284        let num_threads = self.num_threads.unwrap_or_else(num_cpus::get);
285
286        let shared_data = Arc::new(ThreadPoolSharedData {
287            name: self.thread_name,
288            job_receiver: Mutex::new(rx),
289            empty_condvar: Condvar::new(),
290            empty_trigger: Mutex::new(()),
291            join_generation: AtomicUsize::new(0),
292            queued_count: AtomicUsize::new(0),
293            active_count: AtomicUsize::new(0),
294            max_thread_count: AtomicUsize::new(num_threads),
295            panic_count: AtomicUsize::new(0),
296            stack_size: self.thread_stack_size,
297        });
298
299        // Threadpool threads
300        for _ in 0..num_threads {
301            spawn_in_pool(shared_data.clone());
302        }
303
304        ThreadPool {
305            jobs: tx,
306            shared_data: shared_data,
307        }
308    }
309}
310
311struct ThreadPoolSharedData {
312    name: Option<String>,
313    job_receiver: Mutex<Receiver<Thunk<'static>>>,
314    empty_trigger: Mutex<()>,
315    empty_condvar: Condvar,
316    join_generation: AtomicUsize,
317    queued_count: AtomicUsize,
318    active_count: AtomicUsize,
319    max_thread_count: AtomicUsize,
320    panic_count: AtomicUsize,
321    stack_size: Option<usize>,
322}
323
324impl ThreadPoolSharedData {
325    fn has_work(&self) -> bool {
326        self.queued_count.load(Ordering::SeqCst) > 0 || self.active_count.load(Ordering::SeqCst) > 0
327    }
328
329    /// Notify all observers joining this pool if there is no more work to do.
330    fn no_work_notify_all(&self) {
331        if !self.has_work() {
332            *self
333                .empty_trigger
334                .lock()
335                .expect("Unable to notify all joining threads");
336            self.empty_condvar.notify_all();
337        }
338    }
339}
340
341/// Abstraction of a thread pool for basic parallelism.
342pub struct ThreadPool {
343    // How the threadpool communicates with subthreads.
344    //
345    // This is the only such Sender, so when it is dropped all subthreads will
346    // quit.
347    jobs: Sender<Thunk<'static>>,
348    shared_data: Arc<ThreadPoolSharedData>,
349}
350
351impl ThreadPool {
352    /// Creates a new thread pool capable of executing `num_threads` number of jobs concurrently.
353    ///
354    /// # Panics
355    ///
356    /// This function will panic if `num_threads` is 0.
357    ///
358    /// # Examples
359    ///
360    /// Create a new thread pool capable of executing four jobs concurrently:
361    ///
362    /// ```
363    /// use threadpool::ThreadPool;
364    ///
365    /// let pool = ThreadPool::new(4);
366    /// ```
367    pub fn new(num_threads: usize) -> ThreadPool {
368        Builder::new().num_threads(num_threads).build()
369    }
370
371    /// Creates a new thread pool capable of executing `num_threads` number of jobs concurrently.
372    /// Each thread will have the [name][thread name] `name`.
373    ///
374    /// # Panics
375    ///
376    /// This function will panic if `num_threads` is 0.
377    ///
378    /// # Examples
379    ///
380    /// ```rust
381    /// use std::thread;
382    /// use threadpool::ThreadPool;
383    ///
384    /// let pool = ThreadPool::with_name("worker".into(), 2);
385    /// for _ in 0..2 {
386    ///     pool.execute(|| {
387    ///         assert_eq!(
388    ///             thread::current().name(),
389    ///             Some("worker")
390    ///         );
391    ///     });
392    /// }
393    /// pool.join();
394    /// ```
395    ///
396    /// [thread name]: https://doc.rust-lang.org/std/thread/struct.Thread.html#method.name
397    pub fn with_name(name: String, num_threads: usize) -> ThreadPool {
398        Builder::new()
399            .num_threads(num_threads)
400            .thread_name(name)
401            .build()
402    }
403
404    /// **Deprecated: Use [`ThreadPool::with_name`](#method.with_name)**
405    #[inline(always)]
406    #[deprecated(since = "1.4.0", note = "use ThreadPool::with_name")]
407    pub fn new_with_name(name: String, num_threads: usize) -> ThreadPool {
408        Self::with_name(name, num_threads)
409    }
410
411    /// Executes the function `job` on a thread in the pool.
412    ///
413    /// # Examples
414    ///
415    /// Execute four jobs on a thread pool that can run two jobs concurrently:
416    ///
417    /// ```
418    /// use threadpool::ThreadPool;
419    ///
420    /// let pool = ThreadPool::new(2);
421    /// pool.execute(|| println!("hello"));
422    /// pool.execute(|| println!("world"));
423    /// pool.execute(|| println!("foo"));
424    /// pool.execute(|| println!("bar"));
425    /// pool.join();
426    /// ```
427    pub fn execute<F>(&self, job: F)
428    where
429        F: FnOnce() + Send + 'static,
430    {
431        self.shared_data.queued_count.fetch_add(1, Ordering::SeqCst);
432        self.jobs
433            .send(Box::new(job))
434            .expect("ThreadPool::execute unable to send job into queue.");
435    }
436
    /// Returns the number of jobs waiting to be executed in the pool.
438    ///
439    /// # Examples
440    ///
441    /// ```
442    /// use threadpool::ThreadPool;
443    /// use std::time::Duration;
444    /// use std::thread::sleep;
445    ///
446    /// let pool = ThreadPool::new(2);
447    /// for _ in 0..10 {
448    ///     pool.execute(|| {
449    ///         sleep(Duration::from_secs(100));
450    ///     });
451    /// }
452    ///
453    /// sleep(Duration::from_secs(1)); // wait for threads to start
454    /// assert_eq!(8, pool.queued_count());
455    /// ```
456    pub fn queued_count(&self) -> usize {
457        self.shared_data.queued_count.load(Ordering::Relaxed)
458    }
459
460    /// Returns the number of currently active threads.
461    ///
462    /// # Examples
463    ///
464    /// ```
465    /// use threadpool::ThreadPool;
466    /// use std::time::Duration;
467    /// use std::thread::sleep;
468    ///
469    /// let pool = ThreadPool::new(4);
470    /// for _ in 0..10 {
471    ///     pool.execute(move || {
472    ///         sleep(Duration::from_secs(100));
473    ///     });
474    /// }
475    ///
476    /// sleep(Duration::from_secs(1)); // wait for threads to start
477    /// assert_eq!(4, pool.active_count());
478    /// ```
479    pub fn active_count(&self) -> usize {
480        self.shared_data.active_count.load(Ordering::SeqCst)
481    }
482
483    /// Returns the maximum number of threads the pool will execute concurrently.
484    ///
485    /// # Examples
486    ///
487    /// ```
488    /// use threadpool::ThreadPool;
489    ///
490    /// let mut pool = ThreadPool::new(4);
491    /// assert_eq!(4, pool.max_count());
492    ///
493    /// pool.set_num_threads(8);
494    /// assert_eq!(8, pool.max_count());
495    /// ```
496    pub fn max_count(&self) -> usize {
497        self.shared_data.max_thread_count.load(Ordering::Relaxed)
498    }
499
500    /// Returns the number of panicked threads over the lifetime of the pool.
501    ///
502    /// # Examples
503    ///
504    /// ```
505    /// use threadpool::ThreadPool;
506    ///
507    /// let pool = ThreadPool::new(4);
508    /// for n in 0..10 {
509    ///     pool.execute(move || {
510    ///         // simulate a panic
511    ///         if n % 2 == 0 {
512    ///             panic!()
513    ///         }
514    ///     });
515    /// }
516    /// pool.join();
517    ///
518    /// assert_eq!(5, pool.panic_count());
519    /// ```
520    pub fn panic_count(&self) -> usize {
521        self.shared_data.panic_count.load(Ordering::Relaxed)
522    }
523
524    /// **Deprecated: Use [`ThreadPool::set_num_threads`](#method.set_num_threads)**
525    #[deprecated(since = "1.3.0", note = "use ThreadPool::set_num_threads")]
526    pub fn set_threads(&mut self, num_threads: usize) {
527        self.set_num_threads(num_threads)
528    }
529
530    /// Sets the number of worker-threads to use as `num_threads`.
531    /// Can be used to change the threadpool size during runtime.
532    /// Will not abort already running or waiting threads.
533    ///
534    /// # Panics
535    ///
536    /// This function will panic if `num_threads` is 0.
537    ///
538    /// # Examples
539    ///
540    /// ```
541    /// use threadpool::ThreadPool;
542    /// use std::time::Duration;
543    /// use std::thread::sleep;
544    ///
545    /// let mut pool = ThreadPool::new(4);
546    /// for _ in 0..10 {
547    ///     pool.execute(move || {
548    ///         sleep(Duration::from_secs(100));
549    ///     });
550    /// }
551    ///
552    /// sleep(Duration::from_secs(1)); // wait for threads to start
553    /// assert_eq!(4, pool.active_count());
554    /// assert_eq!(6, pool.queued_count());
555    ///
556    /// // Increase thread capacity of the pool
557    /// pool.set_num_threads(8);
558    ///
559    /// sleep(Duration::from_secs(1)); // wait for new threads to start
560    /// assert_eq!(8, pool.active_count());
561    /// assert_eq!(2, pool.queued_count());
562    ///
563    /// // Decrease thread capacity of the pool
564    /// // No active threads are killed
565    /// pool.set_num_threads(4);
566    ///
567    /// assert_eq!(8, pool.active_count());
568    /// assert_eq!(2, pool.queued_count());
569    /// ```
570    pub fn set_num_threads(&mut self, num_threads: usize) {
571        assert!(num_threads >= 1);
572        let prev_num_threads = self
573            .shared_data
574            .max_thread_count
575            .swap(num_threads, Ordering::Release);
576        if let Some(num_spawn) = num_threads.checked_sub(prev_num_threads) {
577            // Spawn new threads
578            for _ in 0..num_spawn {
579                spawn_in_pool(self.shared_data.clone());
580            }
581        }
582    }
583
584    /// Block the current thread until all jobs in the pool have been executed.
585    ///
586    /// Calling `join` on an empty pool will cause an immediate return.
587    /// `join` may be called from multiple threads concurrently.
588    /// A `join` is an atomic point in time. All threads joining before the join
589    /// event will exit together even if the pool is processing new jobs by the
590    /// time they get scheduled.
591    ///
592    /// Calling `join` from a thread within the pool will cause a deadlock. This
593    /// behavior is considered safe.
594    ///
595    /// # Examples
596    ///
597    /// ```
598    /// use threadpool::ThreadPool;
599    /// use std::sync::Arc;
600    /// use std::sync::atomic::{AtomicUsize, Ordering};
601    ///
602    /// let pool = ThreadPool::new(8);
603    /// let test_count = Arc::new(AtomicUsize::new(0));
604    ///
605    /// for _ in 0..42 {
606    ///     let test_count = test_count.clone();
607    ///     pool.execute(move || {
608    ///         test_count.fetch_add(1, Ordering::Relaxed);
609    ///     });
610    /// }
611    ///
612    /// pool.join();
613    /// assert_eq!(42, test_count.load(Ordering::Relaxed));
614    /// ```
615    pub fn join(&self) {
616        // fast path requires no mutex
617        if self.shared_data.has_work() == false {
618            return ();
619        }
620
621        let generation = self.shared_data.join_generation.load(Ordering::SeqCst);
622        let mut lock = self.shared_data.empty_trigger.lock().unwrap();
623
624        while generation == self.shared_data.join_generation.load(Ordering::Relaxed)
625            && self.shared_data.has_work()
626        {
627            lock = self.shared_data.empty_condvar.wait(lock).unwrap();
628        }
629
630        // increase generation if we are the first thread to come out of the loop
631        self.shared_data.join_generation.compare_and_swap(
632            generation,
633            generation.wrapping_add(1),
634            Ordering::SeqCst,
635        );
636    }
637}
638
639impl Clone for ThreadPool {
640    /// Cloning a pool will create a new handle to the pool.
641    /// The behavior is similar to [Arc](https://doc.rust-lang.org/stable/std/sync/struct.Arc.html).
642    ///
643    /// We could for example submit jobs from multiple threads concurrently.
644    ///
645    /// ```
646    /// use threadpool::ThreadPool;
647    /// use std::thread;
648    /// use std::sync::mpsc::channel;
649    ///
650    /// let pool = ThreadPool::with_name("clone example".into(), 2);
651    ///
652    /// let results = (0..2)
653    ///     .map(|i| {
654    ///         let pool = pool.clone();
655    ///         thread::spawn(move || {
656    ///             let (tx, rx) = channel();
657    ///             for i in 1..12 {
658    ///                 let tx = tx.clone();
659    ///                 pool.execute(move || {
660    ///                     tx.send(i).expect("channel will be waiting");
661    ///                 });
662    ///             }
663    ///             drop(tx);
664    ///             if i == 0 {
665    ///                 rx.iter().fold(0, |accumulator, element| accumulator + element)
666    ///             } else {
667    ///                 rx.iter().fold(1, |accumulator, element| accumulator * element)
668    ///             }
669    ///         })
670    ///     })
671    ///     .map(|join_handle| join_handle.join().expect("collect results from threads"))
672    ///     .collect::<Vec<usize>>();
673    ///
674    /// assert_eq!(vec![66, 39916800], results);
675    /// ```
676    fn clone(&self) -> ThreadPool {
677        ThreadPool {
678            jobs: self.jobs.clone(),
679            shared_data: self.shared_data.clone(),
680        }
681    }
682}
683
684/// Create a thread pool with one thread per CPU.
685/// On machines with hyperthreading,
686/// this will create one thread per hyperthread.
687impl Default for ThreadPool {
688    fn default() -> Self {
689        ThreadPool::new(num_cpus::get())
690    }
691}
692
693impl fmt::Debug for ThreadPool {
694    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
695        f.debug_struct("ThreadPool")
696            .field("name", &self.shared_data.name)
697            .field("queued_count", &self.queued_count())
698            .field("active_count", &self.active_count())
699            .field("max_count", &self.max_count())
700            .finish()
701    }
702}
703
704impl PartialEq for ThreadPool {
705    /// Check if you are working with the same pool
706    ///
707    /// ```
708    /// use threadpool::ThreadPool;
709    ///
710    /// let a = ThreadPool::new(2);
711    /// let b = ThreadPool::new(2);
712    ///
713    /// assert_eq!(a, a);
714    /// assert_eq!(b, b);
715    ///
716    /// # // TODO: change this to assert_ne in the future
717    /// assert!(a != b);
718    /// assert!(b != a);
719    /// ```
720    fn eq(&self, other: &ThreadPool) -> bool {
721        let a: &ThreadPoolSharedData = &*self.shared_data;
722        let b: &ThreadPoolSharedData = &*other.shared_data;
723        a as *const ThreadPoolSharedData == b as *const ThreadPoolSharedData
724        // with rust 1.17 and late:
725        // Arc::ptr_eq(&self.shared_data, &other.shared_data)
726    }
727}
728impl Eq for ThreadPool {}
729
730fn spawn_in_pool(shared_data: Arc<ThreadPoolSharedData>) {
731    let mut builder = thread::Builder::new();
732    if let Some(ref name) = shared_data.name {
733        builder = builder.name(name.clone());
734    }
735    if let Some(ref stack_size) = shared_data.stack_size {
736        builder = builder.stack_size(stack_size.to_owned());
737    }
738    builder
739        .spawn(move || {
740            // Will spawn a new thread on panic unless it is cancelled.
741            let sentinel = Sentinel::new(&shared_data);
742
743            loop {
744                // Shutdown this thread if the pool has become smaller
745                let thread_counter_val = shared_data.active_count.load(Ordering::Acquire);
746                let max_thread_count_val = shared_data.max_thread_count.load(Ordering::Relaxed);
747                if thread_counter_val >= max_thread_count_val {
748                    break;
749                }
750                let message = {
751                    // Only lock jobs for the time it takes
752                    // to get a job, not run it.
753                    let lock = shared_data
754                        .job_receiver
755                        .lock()
756                        .expect("Worker thread unable to lock job_receiver");
757                    lock.recv()
758                };
759
760                let job = match message {
761                    Ok(job) => job,
762                    // The ThreadPool was dropped.
763                    Err(..) => break,
764                };
765                // Do not allow IR around the job execution
766                shared_data.active_count.fetch_add(1, Ordering::SeqCst);
767                shared_data.queued_count.fetch_sub(1, Ordering::SeqCst);
768
769                job.call_box();
770
771                shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
772                shared_data.no_work_notify_all();
773            }
774
775            sentinel.cancel();
776        })
777        .unwrap();
778}
779
780#[cfg(test)]
781mod test {
782    use super::{Builder, ThreadPool};
783    use std::sync::atomic::{AtomicUsize, Ordering};
784    use std::sync::mpsc::{channel, sync_channel};
785    use std::sync::{Arc, Barrier};
786    use std::thread::{self, sleep};
787    use std::time::Duration;
788
789    const TEST_TASKS: usize = 4;
790
791    #[test]
792    fn test_set_num_threads_increasing() {
793        let new_thread_amount = TEST_TASKS + 8;
794        let mut pool = ThreadPool::new(TEST_TASKS);
795        for _ in 0..TEST_TASKS {
796            pool.execute(move || sleep(Duration::from_secs(23)));
797        }
798        sleep(Duration::from_secs(1));
799        assert_eq!(pool.active_count(), TEST_TASKS);
800
801        pool.set_num_threads(new_thread_amount);
802
803        for _ in 0..(new_thread_amount - TEST_TASKS) {
804            pool.execute(move || sleep(Duration::from_secs(23)));
805        }
806        sleep(Duration::from_secs(1));
807        assert_eq!(pool.active_count(), new_thread_amount);
808
809        pool.join();
810    }
811
812    #[test]
813    fn test_set_num_threads_decreasing() {
814        let new_thread_amount = 2;
815        let mut pool = ThreadPool::new(TEST_TASKS);
816        for _ in 0..TEST_TASKS {
817            pool.execute(move || {
818                assert_eq!(1, 1);
819            });
820        }
821        pool.set_num_threads(new_thread_amount);
822        for _ in 0..new_thread_amount {
823            pool.execute(move || sleep(Duration::from_secs(23)));
824        }
825        sleep(Duration::from_secs(1));
826        assert_eq!(pool.active_count(), new_thread_amount);
827
828        pool.join();
829    }
830
831    #[test]
832    fn test_active_count() {
833        let pool = ThreadPool::new(TEST_TASKS);
834        for _ in 0..2 * TEST_TASKS {
835            pool.execute(move || loop {
836                sleep(Duration::from_secs(10))
837            });
838        }
839        sleep(Duration::from_secs(1));
840        let active_count = pool.active_count();
841        assert_eq!(active_count, TEST_TASKS);
842        let initialized_count = pool.max_count();
843        assert_eq!(initialized_count, TEST_TASKS);
844    }
845
846    #[test]
847    fn test_works() {
848        let pool = ThreadPool::new(TEST_TASKS);
849
850        let (tx, rx) = channel();
851        for _ in 0..TEST_TASKS {
852            let tx = tx.clone();
853            pool.execute(move || {
854                tx.send(1).unwrap();
855            });
856        }
857
858        assert_eq!(rx.iter().take(TEST_TASKS).fold(0, |a, b| a + b), TEST_TASKS);
859    }
860
861    #[test]
862    #[should_panic]
863    fn test_zero_tasks_panic() {
864        ThreadPool::new(0);
865    }
866
867    #[test]
868    fn test_recovery_from_subtask_panic() {
869        let pool = ThreadPool::new(TEST_TASKS);
870
871        // Panic all the existing threads.
872        for _ in 0..TEST_TASKS {
873            pool.execute(move || panic!("Ignore this panic, it must!"));
874        }
875        pool.join();
876
877        assert_eq!(pool.panic_count(), TEST_TASKS);
878
879        // Ensure new threads were spawned to compensate.
880        let (tx, rx) = channel();
881        for _ in 0..TEST_TASKS {
882            let tx = tx.clone();
883            pool.execute(move || {
884                tx.send(1).unwrap();
885            });
886        }
887
888        assert_eq!(rx.iter().take(TEST_TASKS).fold(0, |a, b| a + b), TEST_TASKS);
889    }
890
891    #[test]
892    fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() {
893        let pool = ThreadPool::new(TEST_TASKS);
894        let waiter = Arc::new(Barrier::new(TEST_TASKS + 1));
895
896        // Panic all the existing threads in a bit.
897        for _ in 0..TEST_TASKS {
898            let waiter = waiter.clone();
899            pool.execute(move || {
900                waiter.wait();
901                panic!("Ignore this panic, it should!");
902            });
903        }
904
905        drop(pool);
906
907        // Kick off the failure.
908        waiter.wait();
909    }
910
911    #[test]
912    fn test_massive_task_creation() {
913        let test_tasks = 4_200_000;
914
915        let pool = ThreadPool::new(TEST_TASKS);
916        let b0 = Arc::new(Barrier::new(TEST_TASKS + 1));
917        let b1 = Arc::new(Barrier::new(TEST_TASKS + 1));
918
919        let (tx, rx) = channel();
920
921        for i in 0..test_tasks {
922            let tx = tx.clone();
923            let (b0, b1) = (b0.clone(), b1.clone());
924
925            pool.execute(move || {
926                // Wait until the pool has been filled once.
927                if i < TEST_TASKS {
928                    b0.wait();
929                    // wait so the pool can be measured
930                    b1.wait();
931                }
932
933                tx.send(1).is_ok();
934            });
935        }
936
937        b0.wait();
938        assert_eq!(pool.active_count(), TEST_TASKS);
939        b1.wait();
940
941        assert_eq!(rx.iter().take(test_tasks).fold(0, |a, b| a + b), test_tasks);
942        pool.join();
943
944        let atomic_active_count = pool.active_count();
945        assert!(
946            atomic_active_count == 0,
947            "atomic_active_count: {}",
948            atomic_active_count
949        );
950    }
951
952    #[test]
953    fn test_shrink() {
954        let test_tasks_begin = TEST_TASKS + 2;
955
956        let mut pool = ThreadPool::new(test_tasks_begin);
957        let b0 = Arc::new(Barrier::new(test_tasks_begin + 1));
958        let b1 = Arc::new(Barrier::new(test_tasks_begin + 1));
959
960        for _ in 0..test_tasks_begin {
961            let (b0, b1) = (b0.clone(), b1.clone());
962            pool.execute(move || {
963                b0.wait();
964                b1.wait();
965            });
966        }
967
968        let b2 = Arc::new(Barrier::new(TEST_TASKS + 1));
969        let b3 = Arc::new(Barrier::new(TEST_TASKS + 1));
970
971        for _ in 0..TEST_TASKS {
972            let (b2, b3) = (b2.clone(), b3.clone());
973            pool.execute(move || {
974                b2.wait();
975                b3.wait();
976            });
977        }
978
979        b0.wait();
980        pool.set_num_threads(TEST_TASKS);
981
982        assert_eq!(pool.active_count(), test_tasks_begin);
983        b1.wait();
984
985        b2.wait();
986        assert_eq!(pool.active_count(), TEST_TASKS);
987        b3.wait();
988    }
989
990    #[test]
991    fn test_name() {
992        let name = "test";
993        let mut pool = ThreadPool::with_name(name.to_owned(), 2);
994        let (tx, rx) = sync_channel(0);
995
996        // initial thread should share the name "test"
997        for _ in 0..2 {
998            let tx = tx.clone();
999            pool.execute(move || {
1000                let name = thread::current().name().unwrap().to_owned();
1001                tx.send(name).unwrap();
1002            });
1003        }
1004
1005        // new spawn thread should share the name "test" too.
1006        pool.set_num_threads(3);
1007        let tx_clone = tx.clone();
1008        pool.execute(move || {
1009            let name = thread::current().name().unwrap().to_owned();
1010            tx_clone.send(name).unwrap();
1011            panic!();
1012        });
1013
1014        // recover thread should share the name "test" too.
1015        pool.execute(move || {
1016            let name = thread::current().name().unwrap().to_owned();
1017            tx.send(name).unwrap();
1018        });
1019
1020        for thread_name in rx.iter().take(4) {
1021            assert_eq!(name, thread_name);
1022        }
1023    }
1024
1025    #[test]
1026    fn test_debug() {
1027        let pool = ThreadPool::new(4);
1028        let debug = format!("{:?}", pool);
1029        assert_eq!(
1030            debug,
1031            "ThreadPool { name: None, queued_count: 0, active_count: 0, max_count: 4 }"
1032        );
1033
1034        let pool = ThreadPool::with_name("hello".into(), 4);
1035        let debug = format!("{:?}", pool);
1036        assert_eq!(
1037            debug,
1038            "ThreadPool { name: Some(\"hello\"), queued_count: 0, active_count: 0, max_count: 4 }"
1039        );
1040
1041        let pool = ThreadPool::new(4);
1042        pool.execute(move || sleep(Duration::from_secs(5)));
1043        sleep(Duration::from_secs(1));
1044        let debug = format!("{:?}", pool);
1045        assert_eq!(
1046            debug,
1047            "ThreadPool { name: None, queued_count: 0, active_count: 1, max_count: 4 }"
1048        );
1049    }
1050
1051    #[test]
1052    fn test_repeate_join() {
1053        let pool = ThreadPool::with_name("repeate join test".into(), 8);
1054        let test_count = Arc::new(AtomicUsize::new(0));
1055
1056        for _ in 0..42 {
1057            let test_count = test_count.clone();
1058            pool.execute(move || {
1059                sleep(Duration::from_secs(2));
1060                test_count.fetch_add(1, Ordering::Release);
1061            });
1062        }
1063
1064        println!("{:?}", pool);
1065        pool.join();
1066        assert_eq!(42, test_count.load(Ordering::Acquire));
1067
1068        for _ in 0..42 {
1069            let test_count = test_count.clone();
1070            pool.execute(move || {
1071                sleep(Duration::from_secs(2));
1072                test_count.fetch_add(1, Ordering::Relaxed);
1073            });
1074        }
1075        pool.join();
1076        assert_eq!(84, test_count.load(Ordering::Relaxed));
1077    }
1078
1079    #[test]
1080    fn test_multi_join() {
1081        use std::sync::mpsc::TryRecvError::*;
1082
1083        // Toggle the following lines to debug the deadlock
1084        fn error(_s: String) {
1085            //use ::std::io::Write;
1086            //let stderr = ::std::io::stderr();
1087            //let mut stderr = stderr.lock();
1088            //stderr.write(&_s.as_bytes()).is_ok();
1089        }
1090
1091        let pool0 = ThreadPool::with_name("multi join pool0".into(), 4);
1092        let pool1 = ThreadPool::with_name("multi join pool1".into(), 4);
1093        let (tx, rx) = channel();
1094
1095        for i in 0..8 {
1096            let pool1 = pool1.clone();
1097            let pool0_ = pool0.clone();
1098            let tx = tx.clone();
1099            pool0.execute(move || {
1100                pool1.execute(move || {
1101                    error(format!("p1: {} -=- {:?}\n", i, pool0_));
1102                    pool0_.join();
1103                    error(format!("p1: send({})\n", i));
1104                    tx.send(i).expect("send i from pool1 -> main");
1105                });
1106                error(format!("p0: {}\n", i));
1107            });
1108        }
1109        drop(tx);
1110
1111        assert_eq!(rx.try_recv(), Err(Empty));
1112        error(format!("{:?}\n{:?}\n", pool0, pool1));
1113        pool0.join();
1114        error(format!("pool0.join() complete =-= {:?}", pool1));
1115        pool1.join();
1116        error("pool1.join() complete\n".into());
1117        assert_eq!(
1118            rx.iter().fold(0, |acc, i| acc + i),
1119            0 + 1 + 2 + 3 + 4 + 5 + 6 + 7
1120        );
1121    }
1122
1123    #[test]
1124    fn test_empty_pool() {
1125        // Joining an empty pool must return imminently
1126        let pool = ThreadPool::new(4);
1127
1128        pool.join();
1129
1130        assert!(true);
1131    }
1132
1133    #[test]
1134    fn test_no_fun_or_joy() {
1135        // What happens when you keep adding jobs after a join
1136
1137        fn sleepy_function() {
1138            sleep(Duration::from_secs(6));
1139        }
1140
1141        let pool = ThreadPool::with_name("no fun or joy".into(), 8);
1142
1143        pool.execute(sleepy_function);
1144
1145        let p_t = pool.clone();
1146        thread::spawn(move || {
1147            (0..23).map(|_| p_t.execute(sleepy_function)).count();
1148        });
1149
1150        pool.join();
1151    }
1152
1153    #[test]
1154    fn test_clone() {
1155        let pool = ThreadPool::with_name("clone example".into(), 2);
1156
1157        // This batch of jobs will occupy the pool for some time
1158        for _ in 0..6 {
1159            pool.execute(move || {
1160                sleep(Duration::from_secs(2));
1161            });
1162        }
1163
1164        // The following jobs will be inserted into the pool in a random fashion
1165        let t0 = {
1166            let pool = pool.clone();
1167            thread::spawn(move || {
1168                // wait for the first batch of tasks to finish
1169                pool.join();
1170
1171                let (tx, rx) = channel();
1172                for i in 0..42 {
1173                    let tx = tx.clone();
1174                    pool.execute(move || {
1175                        tx.send(i).expect("channel will be waiting");
1176                    });
1177                }
1178                drop(tx);
1179                rx.iter()
1180                    .fold(0, |accumulator, element| accumulator + element)
1181            })
1182        };
1183        let t1 = {
1184            let pool = pool.clone();
1185            thread::spawn(move || {
1186                // wait for the first batch of tasks to finish
1187                pool.join();
1188
1189                let (tx, rx) = channel();
1190                for i in 1..12 {
1191                    let tx = tx.clone();
1192                    pool.execute(move || {
1193                        tx.send(i).expect("channel will be waiting");
1194                    });
1195                }
1196                drop(tx);
1197                rx.iter()
1198                    .fold(1, |accumulator, element| accumulator * element)
1199            })
1200        };
1201
1202        assert_eq!(
1203            861,
1204            t0.join()
1205                .expect("thread 0 will return after calculating additions",)
1206        );
1207        assert_eq!(
1208            39916800,
1209            t1.join()
1210                .expect("thread 1 will return after calculating multiplications",)
1211        );
1212    }
1213
1214    #[test]
1215    fn test_sync_shared_data() {
1216        fn assert_sync<T: Sync>() {}
1217        assert_sync::<super::ThreadPoolSharedData>();
1218    }
1219
1220    #[test]
1221    fn test_send_shared_data() {
1222        fn assert_send<T: Send>() {}
1223        assert_send::<super::ThreadPoolSharedData>();
1224    }
1225
1226    #[test]
1227    fn test_send() {
1228        fn assert_send<T: Send>() {}
1229        assert_send::<ThreadPool>();
1230    }
1231
1232    #[test]
1233    fn test_cloned_eq() {
1234        let a = ThreadPool::new(2);
1235
1236        assert_eq!(a, a.clone());
1237    }
1238
1239    #[test]
1240    /// The scenario is joining threads should not be stuck once their wave
1241    /// of joins has completed. So once one thread joining on a pool has
1242    /// succeded other threads joining on the same pool must get out even if
1243    /// the thread is used for other jobs while the first group is finishing
1244    /// their join
1245    ///
1246    /// In this example this means the waiting threads will exit the join in
1247    /// groups of four because the waiter pool has four workers.
1248    fn test_join_wavesurfer() {
1249        let n_cycles = 4;
1250        let n_workers = 4;
1251        let (tx, rx) = channel();
1252        let builder = Builder::new()
1253            .num_threads(n_workers)
1254            .thread_name("join wavesurfer".into());
1255        let p_waiter = builder.clone().build();
1256        let p_clock = builder.build();
1257
1258        let barrier = Arc::new(Barrier::new(3));
1259        let wave_clock = Arc::new(AtomicUsize::new(0));
1260        let clock_thread = {
1261            let barrier = barrier.clone();
1262            let wave_clock = wave_clock.clone();
1263            thread::spawn(move || {
1264                barrier.wait();
1265                for wave_num in 0..n_cycles {
1266                    wave_clock.store(wave_num, Ordering::SeqCst);
1267                    sleep(Duration::from_secs(1));
1268                }
1269            })
1270        };
1271
1272        {
1273            let barrier = barrier.clone();
1274            p_clock.execute(move || {
1275                barrier.wait();
1276                // this sleep is for stabilisation on weaker platforms
1277                sleep(Duration::from_millis(100));
1278            });
1279        }
1280
1281        // prepare three waves of jobs
1282        for i in 0..3 * n_workers {
1283            let p_clock = p_clock.clone();
1284            let tx = tx.clone();
1285            let wave_clock = wave_clock.clone();
1286            p_waiter.execute(move || {
1287                let now = wave_clock.load(Ordering::SeqCst);
1288                p_clock.join();
1289                // submit jobs for the second wave
1290                p_clock.execute(|| sleep(Duration::from_secs(1)));
1291                let clock = wave_clock.load(Ordering::SeqCst);
1292                tx.send((now, clock, i)).unwrap();
1293            });
1294        }
1295        println!("all scheduled at {}", wave_clock.load(Ordering::SeqCst));
1296        barrier.wait();
1297
1298        p_clock.join();
1299        //p_waiter.join();
1300
1301        drop(tx);
1302        let mut hist = vec![0; n_cycles];
1303        let mut data = vec![];
1304        for (now, after, i) in rx.iter() {
1305            let mut dur = after - now;
1306            if dur >= n_cycles - 1 {
1307                dur = n_cycles - 1;
1308            }
1309            hist[dur] += 1;
1310
1311            data.push((now, after, i));
1312        }
1313        for (i, n) in hist.iter().enumerate() {
1314            println!(
1315                "\t{}: {} {}",
1316                i,
1317                n,
1318                &*(0..*n).fold("".to_owned(), |s, _| s + "*")
1319            );
1320        }
1321        assert!(data.iter().all(|&(cycle, stop, i)| if i < n_workers {
1322            cycle == stop
1323        } else {
1324            cycle < stop
1325        }));
1326
1327        clock_thread.join().unwrap();
1328    }
1329}