上一篇刚说完 JDK 中的线程池,这次尝试着自己来定义一个线程池
参考 ThreadPoolExecutor
的实现可以知道,要自定义一个线程池,总少不了以下一些基本元素:
核心线程数 CoreSize
任务队列 TaskQueue
拒绝策略 RejectPolicy
获取任务的超时时间 Timeout(工作线程空闲时从任务队列等待新任务的最长时间)
超时时间单位 TimeUnit
下面分步骤实现.
一. 自定义任务队列 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 class BlockingQueue <T > { private Deque<T> queue; private ReentrantLock lock; private Condition fullWaitSignal; private Condition emptyWaitSignal; private int capacity; public BlockingQueue (int capacity) { queue = new ArrayDeque<>(); lock = new ReentrantLock(); fullWaitSignal = lock.newCondition(); emptyWaitSignal = lock.newCondition(); this .capacity = capacity; } public T poll (long timeout, TimeUnit timeUnit) { lock.lock(); try { long nanos = timeUnit.toNanos(timeout); while (queue.isEmpty()) { if (nanos <= 0 ) { return null ; } try { nanos = emptyWaitSignal.awaitNanos(nanos); } catch (InterruptedException e) { e.printStackTrace(); } } T t = queue.removeFirst(); System.out.printf("task [%s] removed from queue\n" , t); fullWaitSignal.signal(); return t; } finally { lock.unlock(); } } public T take () { return poll(0 , TimeUnit.NANOSECONDS); } public void put (T t) { lock.lock(); try { while (queue.size() == capacity) { try { fullWaitSignal.await(); } catch (InterruptedException e) { e.printStackTrace(); } } System.out.printf("[put]-add task [%s] to queue\n" , t); queue.addLast(t); emptyWaitSignal.signal(); } finally { lock.unlock(); } } public boolean offer (T t, long timeout, TimeUnit timeUnit) { lock.lock(); try { long nanos = timeUnit.toNanos(timeout); while (queue.size() == capacity) { if (nanos <= 0 ) return false ; try { nanos = fullWaitSignal.awaitNanos(nanos); } catch (InterruptedException e) { e.printStackTrace(); } } System.out.printf("[offer]-task [%s] add into queue\n" , t); queue.addLast(t); emptyWaitSignal.signal(); return true ; } finally { lock.unlock(); } } public 
boolean tryPut (T t, RejectedPolicy policy) { lock.lock(); try { if (queue.size() == capacity) { policy.reject(t, this ); return false ; } else { System.out.printf("[try put]-task [%s] add into queue\n" , t); queue.addLast(t); emptyWaitSignal.signal(); return true ; } } finally { lock.unlock(); } } public int size () { lock.lock(); try { return queue.size(); } finally { lock.unlock(); } } }
二. 自定义拒绝策略接口 1 2 3 4 5 @FunctionalInterface interface RejectedPolicy <T > { void reject (T task, BlockingQueue<T> queue) ; }
三. 自定义线程池 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 public class CustomizedThreadPool { private int coresize; private BlockingQueue<Runnable> taskQueue; private Set<Worker> workers; private long timeout; private TimeUnit timeUnit; private RejectedPolicy<Runnable> rejectedPolicy; public CustomizedThreadPool (int coresize, int queueSize, long timeout, TimeUnit timeUnit, RejectedPolicy<Runnable> rejectedPolicy) { this .coresize = coresize; this .taskQueue = new BlockingQueue<>(queueSize); workers = new HashSet<>(); this .timeout = timeout; this .timeUnit = timeUnit; this .rejectedPolicy = rejectedPolicy; } public void execute (Runnable runnable) { synchronized (workers) { if (workers.size() < coresize) { Worker worker = new Worker(runnable); workers.add(worker); System.out.printf("worker [%s] created, task [%s]\n" , worker, runnable); worker.start(); } else { taskQueue.tryPut(runnable, rejectedPolicy); } } } class Worker extends Thread { private Runnable task; public Worker (Runnable task) { this .task = task; } @Override public void run () { while (task != null || (task = taskQueue.poll(timeout, timeUnit)) != null ) { try { System.out.printf("running [%s]\n" , task); task.run(); } finally { task = null ; } synchronized (workers) { System.out.printf("remove worker [%s]" , this ); workers.remove(this ); System.out.printf("worker size is:%d\n" , workers.size()); } } } } }
四. 编写测试用例 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 public static void main (String[] args) { CustomizedThreadPool threadPool = new CustomizedThreadPool(4 , 5 , 1L , TimeUnit.SECONDS, (task, queue) -> { task.run(); }); for (int i=0 ; i < 20 ; i++) { int idx = i; threadPool.execute(() -> { try { TimeUnit.MILLISECONDS.sleep(2000L ); } catch (InterruptedException e) { e.printStackTrace(); } System.out.printf("%d\n" , idx); }); } }