/**
 * Kotlin thread-pool wrapper (kotlin 线程池封装).
 *
 * Time: 2025-03-20 13:13:36
 *
 * Supports task scheduling, pausing and cancellation (支持任务调度、暂停、取消等操作).
 *
 * Wraps a [ThreadPoolExecutor] and tracks submitted tasks by a caller-chosen tag so they can be
 * paused (cancelled), resumed (resubmitted) or cancelled individually. The same executor is also
 * exposed as a coroutine dispatcher via [coroutineDispatcher].
 *
 * Tag bookkeeping: a tag is registered in [taskMap] while its task is queued or running and is
 * removed when that task completes, fails or is cancelled. Removal is identity-based, so a task
 * that finishes late never evicts a newer task registered under the same tag.
 */
class ThreadPoolManager private constructor(builder: Builder) {

    // region Configuration
    private val corePoolSize: Int
    private val maxPoolSize: Int
    private val keepAliveTime: Long
    private val workQueue: BlockingQueue<Runnable>
    private val threadFactory: ThreadFactory
    private val exceptionHandler: ((Thread, Throwable) -> Unit)?
    // endregion

    // region Core components
    private val executor: ThreadPoolExecutor

    /** Coroutine dispatcher that runs coroutines on this manager's executor. */
    val coroutineDispatcher: CoroutineContext

    /** Queued or running tasks keyed by tag; a task removes its own entry when it finishes. */
    internal val taskMap = ConcurrentHashMap<String, Future<*>>()
    // endregion

    init {
        corePoolSize = builder.corePoolSize
        maxPoolSize = builder.maxPoolSize
        keepAliveTime = builder.keepAliveTime
        workQueue = builder.workQueue ?: LinkedBlockingQueue<Runnable>()
        // A custom threadFactory takes precedence; exceptionHandler is only installed on threads
        // created by the default factory (task failures from submit(Runnable) are still reported).
        threadFactory = builder.threadFactory ?: DefaultThreadFactory(builder.exceptionHandler)
        exceptionHandler = builder.exceptionHandler

        // Builder.maxPoolSize is derived from the *default* corePoolSize, so raising only
        // corePoolSize trips ThreadPoolExecutor's argument check; fail with a readable message.
        require(maxPoolSize >= corePoolSize) {
            "maxPoolSize ($maxPoolSize) must be >= corePoolSize ($corePoolSize)"
        }

        executor = ThreadPoolExecutor(
            corePoolSize,
            maxPoolSize,
            keepAliveTime,
            builder.keepAliveUnit,
            workQueue,
            threadFactory
        ).apply {
            allowCoreThreadTimeOut(builder.allowCoreThreadTimeout)
        }

        coroutineDispatcher = executor.asCoroutineDispatcher()
    }

    // region Preset templates
    companion object {
        /**
         * Preset for I/O-bound work such as network requests: core = 2 x CPUs, max = 4 x CPUs,
         * bounded queue of 128. Threads beyond core are only started once the queue is full, and
         * submissions are rejected when both the queue and the max thread count are exhausted.
         *
         * NOTE(review): keepAliveTime is measured in the Builder's default unit (milliseconds), so
         * 30L means 30 ms; set keepAliveUnit = TimeUnit.SECONDS if seconds were intended.
         */
        fun newIOThreadPool(): ThreadPoolManager {
            val cpuCount = Runtime.getRuntime().availableProcessors()
            return Builder().apply {
                corePoolSize = cpuCount * 2
                maxPoolSize = cpuCount * 4
                keepAliveTime = 30L
                workQueue = LinkedBlockingQueue<Runnable>(128)
            }.build()
        }

        /**
         * Preset for CPU-bound work: core = CPUs + 1, max = 2 x CPUs, bounded queue of 64.
         *
         * NOTE(review): the original note recommended this pool for database operations; blocking
         * database calls are usually I/O-bound, so confirm which preset they belong on. As above,
         * keepAliveTime 10L is in milliseconds.
         */
        fun newCPUThreadPool(): ThreadPoolManager {
            val cpuCount = Runtime.getRuntime().availableProcessors()
            return Builder().apply {
                corePoolSize = cpuCount + 1
                maxPoolSize = cpuCount * 2
                keepAliveTime = 10L
                workQueue = LinkedBlockingQueue<Runnable>(64)
            }.build()
        }
    }
    // endregion

    // region Task management
    /**
     * Submits [task] under [tag] unless a task with that tag is already queued or running.
     *
     * The caller gets no [Future] back, so if [task] throws, the exception is passed to the
     * builder's `exceptionHandler` (or logged when none is set) instead of being silently lost.
     *
     * @return `true` if the task was accepted, `false` if [tag] is already in use.
     * @throws java.util.concurrent.RejectedExecutionException if the executor rejects the task
     *   (pool shut down, or queue and max threads saturated); [tag] stays free in that case.
     */
    fun submit(tag: String, task: Runnable): Boolean {
        val tracked = TrackedTask(tag, Callable { task.run() }, reportFailure = true)
        // putIfAbsent claims the tag atomically, so concurrent submits cannot both succeed, and
        // registering before execution guarantees the task's done() finds its own entry.
        if (taskMap.putIfAbsent(tag, tracked) != null) return false
        try {
            executor.execute(tracked)
        } catch (e: java.util.concurrent.RejectedExecutionException) {
            taskMap.remove(tag, tracked)
            throw e
        }
        return true
    }

    /**
     * Submits [task] under [tag] and returns its [Future]; failures surface through that future.
     *
     * Unlike the [Runnable] overload this never refuses a tag: an existing entry is replaced in
     * [taskMap]. The earlier task keeps running but is no longer tracked.
     *
     * @throws java.util.concurrent.RejectedExecutionException if the executor rejects the task;
     *   [taskMap] is left unchanged in that case.
     */
    fun <T> submit(tag: String, task: Callable<T>): Future<T> {
        val tracked = TrackedTask(tag, task, reportFailure = false)
        // Execute before registering so a rejection leaves any existing entry untouched.
        executor.execute(tracked)
        taskMap[tag] = tracked
        // If the task finished before it was registered, its done() had nothing to remove;
        // drop the entry here so a completed task never lingers under its tag.
        if (tracked.isDone) taskMap.remove(tag, tracked)
        return tracked
    }

    /**
     * "Pauses" the task under [tag] by cancelling it, interrupting it if it is running. This is
     * not a real suspension: progress is lost, and [resume] starts the task again from scratch.
     * The tag is freed as soon as the cancellation takes effect, including for tasks that were
     * still queued.
     *
     * @return the result of [Future.cancel], or `null` if no task is registered under [tag].
     */
    fun pause(tag: String): Boolean? = taskMap[tag]?.cancel(true)

    /** Resubmits [task] under [tag]; identical to the [Runnable] overload of [submit]. */
    fun resume(tag: String, task: Runnable): Boolean = submit(tag, task)

    /** Cancels (interrupting if running) and unregisters the task under [tag], if any. */
    fun cancel(tag: String) {
        // Remove first, atomically, so a task newly submitted under the same tag is not evicted.
        taskMap.remove(tag)?.cancel(true)
    }

    /**
     * Shuts the pool down and clears all tag registrations.
     *
     * @param immediate `true` attempts to interrupt running tasks and discards queued ones at
     *   once; `false` lets running and queued tasks finish, waiting up to 30 seconds before
     *   falling back to the immediate path.
     *
     * Queued tasks discarded by the immediate path are cancelled, so callers blocked in
     * [Future.get] on them are released instead of waiting forever.
     */
    fun shutdown(immediate: Boolean = false) {
        if (immediate) {
            shutdownNowAndCancelPending()
        } else {
            executor.shutdown()
            try {
                if (!executor.awaitTermination(30, TimeUnit.SECONDS)) {
                    shutdownNowAndCancelPending()
                }
            } catch (e: InterruptedException) {
                shutdownNowAndCancelPending()
                // Restore the interrupt status for the caller.
                Thread.currentThread().interrupt()
            }
        }
        taskMap.clear()
    }
    // endregion

    // region Internal implementation
    /** Runs shutdownNow() and cancels every drained task that is a [Future]. */
    private fun shutdownNowAndCancelPending() {
        executor.shutdownNow().forEach { pending -> (pending as? Future<*>)?.cancel(false) }
    }

    /**
     * A [java.util.concurrent.FutureTask] that unregisters itself from [taskMap] when it
     * completes, fails or is cancelled.
     *
     * @param reportFailure when `true`, an exception thrown by the task is passed to the
     *   builder's `exceptionHandler`, or logged if none was configured.
     */
    private inner class TrackedTask<T>(
        private val tag: String,
        callable: Callable<T>,
        private val reportFailure: Boolean
    ) : java.util.concurrent.FutureTask<T>(callable) {

        override fun done() {
            // Identity-based removal: never evict a newer task registered under the same tag.
            taskMap.remove(tag, this)
        }

        override fun setException(t: Throwable) {
            super.setException(t)
            // If the future was already cancelled (e.g. via pause), super.setException left it
            // unchanged and the failure is not reported.
            if (reportFailure && !isCancelled) {
                val handler = exceptionHandler
                if (handler != null) {
                    handler(Thread.currentThread(), t)
                } else {
                    Log.e("ThreadPool", "Task '$tag' failed", t)
                }
            }
        }
    }

    /** Collects pool settings; call [build] to create the [ThreadPoolManager]. */
    class Builder {
        /** Threads kept even when idle, unless [allowCoreThreadTimeout] is set. */
        var corePoolSize: Int = Runtime.getRuntime().availableProcessors()

        /**
         * Maximum number of pool threads. The default is twice the *default* core size, computed
         * when the Builder is created; raising [corePoolSize] does not raise it automatically.
         */
        var maxPoolSize: Int = corePoolSize * 2

        /** Idle time before surplus threads are reclaimed, measured in [keepAliveUnit]. */
        var keepAliveTime: Long = 10L

        /** Unit of [keepAliveTime]; milliseconds by default for backward compatibility. */
        var keepAliveUnit: TimeUnit = TimeUnit.MILLISECONDS

        /** Task queue; an unbounded [LinkedBlockingQueue] is used when left `null`. */
        var workQueue: BlockingQueue<Runnable>? = null

        /** Custom thread factory; when set, [exceptionHandler] is not installed on its threads. */
        var threadFactory: ThreadFactory? = null

        /** Whether idle core threads may also time out after [keepAliveTime]. */
        var allowCoreThreadTimeout: Boolean = false

        /** Receives uncaught thread exceptions and failures of tasks submitted as [Runnable]. */
        var exceptionHandler: ((Thread, Throwable) -> Unit)? = null

        fun build() = ThreadPoolManager(this)
    }

    /**
     * Thread factory used when the builder supplies none: creates non-daemon, normal-priority
     * threads named `pool-<id>-thread-<n>` whose uncaught exceptions go to [exceptionHandler],
     * or are logged when it is `null`.
     */
    private class DefaultThreadFactory(
        private val exceptionHandler: ((Thread, Throwable) -> Unit)?
    ) : ThreadFactory {
        private val group: ThreadGroup
        private val threadNumber = AtomicInteger(1)
        private val namePrefix: String

        init {
            // NOTE(review): System.getSecurityManager() is deprecated for removal since Java 17
            // and returns null on Android; the fallback is the creating thread's group.
            val s = System.getSecurityManager()
            group = s?.threadGroup ?: Thread.currentThread().threadGroup
            namePrefix = "pool-${System.identityHashCode(this)}-thread-"
        }

        override fun newThread(r: Runnable): Thread {
            return Thread(group, r, namePrefix + threadNumber.getAndIncrement()).apply {
                isDaemon = false
                priority = Thread.NORM_PRIORITY
                uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { t, e ->
                    val handler = exceptionHandler
                    if (handler != null) {
                        handler(t, e)
                    } else {
                        Log.e("ThreadPool", "Uncaught exception in thread ${t.name}", e)
                    }
                }
            }
        }
    }
    // endregion
}