In theory, everything in this article applies equally to the Pool in multiprocessing.dummy.
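In Python 2.x, pressing Ctrl+C while the function-based process pool provided by multiprocessing is running does not stop all the processes and exit. Instead you have to press Ctrl+Z, then hunt down the leftover child processes and kill them. First, here is a piece of code on which Ctrl+C has no effect: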
#!/usr/bin/env python
import multiprocessing
import os
import time

def do_work(x):
    print 'Work Started: %s' % os.getpid()
    time.sleep(10)
    return x * x

def main():
    pool = multiprocessing.Pool(4)
    try:
        result = pool.map_async(do_work, range(8))
        pool.close()
        # Python 2.x: the main process blocks here and ^C never
        # reaches the except clause below
        pool.join()
        print result
    except KeyboardInterrupt:
        print 'parent received control-c'
        pool.terminate()
        pool.join()

if __name__ == "__main__":
    main()
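After starting this code, pressing ^C kills not a single process: 5 processes are left over, the main process plus the 4 workers (1+4), and only killing the main process makes them all exit. Clearly, while the process pool is in use, the KeyboardInterrupt never gets caught by the main process. There are two ways to fix this.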
Solution 1
The following snippet comes from pool.py in the multiprocessing package of the Python source. ApplyResult is the class Pool uses to hold the result of a function call:
class ApplyResult(object):

    def __init__(self, cache, callback):
        self._cond = threading.Condition(threading.Lock())
        self._job = job_counter.next()
        self._cache = cache
        self._ready = False
        self._callback = callback
        cache[self._job] = self
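The part that matters is self._cond: callers of get() wait on this Condition until the pool's result handler thread marks the result ready. The following sketch imitates that wait/get logic in a much simplified form, written for Python 3. SimpleResult and _set are hypothetical names; the real pool.py also handles callbacks, worker errors and the job cache, and raises multiprocessing.TimeoutError rather than the built-in TimeoutError used here.

```python
import threading


class SimpleResult:
    """Hypothetical, simplified stand-in for multiprocessing.pool.ApplyResult.

    Another thread calls _set(); the caller of get() blocks on a Condition
    until the value arrives or the timeout expires.
    """

    def __init__(self):
        self._cond = threading.Condition(threading.Lock())
        self._ready = False
        self._value = None

    def _set(self, value):
        # Called when the job finishes: store the value and wake waiters.
        with self._cond:
            self._value = value
            self._ready = True
            self._cond.notify_all()

    def wait(self, timeout=None):
        with self._cond:
            if not self._ready:
                # timeout=None waits until notified; a number bounds the wait.
                self._cond.wait(timeout)

    def get(self, timeout=None):
        self.wait(timeout)
        if not self._ready:
            # Not ready in time: give up (the real pool raises
            # multiprocessing.TimeoutError here).
            raise TimeoutError("result not ready")
        return self._value
```

With a 10-second job, get(1) on such an object gives up after about a second and raises, which is why the timeout handed to get must cover the whole run.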
The following code is likewise immune to ^C:
if __name__ == '__main__':
    import threading
    cond = threading.Condition(threading.Lock())
    cond.acquire()
    cond.wait()
    print "done"
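Clearly, a thread waiting on a threading.Condition(threading.Lock()) object cannot receive KeyboardInterrupt. With one small change, passing a timeout argument to cond.wait(), it can: in Python 2.x, Condition.wait() with a timeout polls the lock in a loop of short sleeps instead of blocking on it once, so the main thread regains control regularly and the signal gets handled. For the pool, this timeout can be supplied by calling get on the result of map_async. Change
result = pool.map_async(do_work, range(8))
to
result = pool.map_async(do_work, range(8)).get(99999)
and ^C is received as expected. The exact number passed to get hardly matters (99999 and 0xffff both work), but it must be longer than the jobs take in total: if the results are not ready when the timeout expires, get raises multiprocessing.TimeoutError, so a value like 1 fails with the 10-second jobs above.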
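The same fix can be wrapped in a small helper. This is a sketch written for Python 3 rather than the Python 2.x of the article; the names map_interruptibly and square are hypothetical, and the default timeout of 99999 seconds is simply the large value used above. Passing multiprocessing.dummy.Pool as pool_factory runs the same code on threads, which, as noted at the start, should behave the same way.

```python
import multiprocessing
import multiprocessing.dummy  # thread-backed Pool with the same interface


def square(x):
    return x * x


def map_interruptibly(func, iterable, processes=4, timeout=99999,
                      pool_factory=multiprocessing.Pool):
    """Like pool.map, but waits with a timeout so Ctrl+C can get through."""
    pool = pool_factory(processes)
    try:
        # A bounded get() instead of an unbounded wait. The timeout must
        # exceed the total run time, otherwise get() raises
        # multiprocessing.TimeoutError.
        results = pool.map_async(func, iterable).get(timeout)
    except BaseException:
        # KeyboardInterrupt, a timeout or a worker error:
        # kill the workers before propagating the exception.
        pool.terminate()
        pool.join()
        raise
    pool.close()
    pool.join()
    return results
```

In a script, call it as map_interruptibly(square, range(8)) under an if __name__ == "__main__": guard. Terminating on any exception, not only KeyboardInterrupt, keeps a timeout from leaving orphaned workers behind.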
Solution 2
The other approach, naturally, is to write your own process pool on top of queues. Here is some code to give you a feel for it:
#!/usr/bin/env python
import multiprocessing, os, signal, time, Queue

def do_work():
    print 'Work Started: %d' % os.getpid()
    time.sleep(2)
    return 'Success'

def manual_function(job_queue, result_queue):
    # Children ignore SIGINT; only the parent handles ^C
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    while not job_queue.empty():
        try:
            job = job_queue.get(block=False)
            result_queue.put(do_work())
        except Queue.Empty:
            # Another worker took the last job first
            pass
        #except KeyboardInterrupt: pass

def main():
    job_queue = multiprocessing.Queue()
    result_queue = multiprocessing.Queue()

    # Six dummy jobs for three worker processes
    for i in range(6):
        job_queue.put(None)

    workers = []
    for i in range(3):
        tmp = multiprocessing.Process(target=manual_function,
                                      args=(job_queue, result_queue))
        tmp.start()
        workers.append(tmp)

    try:
        # Unlike Pool.join(), this wait can be interrupted by ^C
        for worker in workers:
            worker.join()
    except KeyboardInterrupt:
        print 'parent received ctrl-c'
        for worker in workers:
            worker.terminate()
            worker.join()

    while not result_queue.empty():
        print result_queue.get(block=False)

if __name__ == "__main__":
    main()
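The loop above relies on job_queue.empty(), which the multiprocessing documentation describes as not reliable, and it reads results only after joining the workers. Below is a hedged Python 3 sketch of the same hand-rolled pool that instead sends one sentinel per worker and drains the result queue before joining, since the documentation warns that a process which has put items on a queue may not exit until they are consumed. The names STOP, worker_loop, child_main and run_manual_pool are hypothetical, not part of any library.

```python
import multiprocessing
import signal

STOP = None  # sentinel: each worker exits when it receives one


def square(x):
    return x * x


def worker_loop(func, job_queue, result_queue):
    """Apply func to each job until the STOP sentinel arrives."""
    # iter(callable, sentinel) keeps calling the blocking job_queue.get()
    # until it returns STOP, so there is no empty()/get(block=False) race.
    for job in iter(job_queue.get, STOP):
        result_queue.put(func(job))


def child_main(func, job_queue, result_queue):
    # Children ignore SIGINT, so only the parent reacts to Ctrl+C.
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    worker_loop(func, job_queue, result_queue)


def run_manual_pool(func, jobs, n_workers=3):
    jobs = list(jobs)
    job_queue = multiprocessing.Queue()
    result_queue = multiprocessing.Queue()
    for job in jobs:
        job_queue.put(job)
    for _ in range(n_workers):
        job_queue.put(STOP)  # one sentinel per worker

    workers = [multiprocessing.Process(target=child_main,
                                       args=(func, job_queue, result_queue))
               for _ in range(n_workers)]
    for w in workers:
        w.start()
    try:
        # Drain results before join(): joining a child that still has
        # buffered queue data can deadlock.
        results = [result_queue.get() for _ in jobs]
        for w in workers:
            w.join()
    except KeyboardInterrupt:
        print("parent received ctrl-c")
        for w in workers:
            w.terminate()
            w.join()
        raise
    return results  # completion order, not submission order
```

In a script, call run_manual_pool(square, range(6)) under an if __name__ == "__main__": guard. If a worker dies before producing its results, result_queue.get() waits forever, so a sturdier version would pass a timeout there.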
A common wrong solution
This one has to be mentioned, because I have seen people on SegmentFault misled by it.
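In theory, you pass an initializer function when creating the Pool so that the child processes ignore the SIGINT signal (that is, ^C), and let the Pool handle the shutdown with terminate. The code: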
#!/usr/bin/env python
import multiprocessing
import os
import signal
import time

def init_worker():
    # Make the pool's child processes ignore SIGINT (^C)
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def run_worker(x):
    print "child: %s" % os.getpid()
    time.sleep(20)
    return x * x

def main():
    pool = multiprocessing.Pool(4, init_worker)
    try:
        results = []
        print "Starting jobs"
        for x in range(8):
            results.append(pool.apply_async(run_worker, args=(x,)))

        # ^C is only caught during this sleep
        time.sleep(5)
        pool.close()
        pool.join()
        print [x.get() for x in results]
    except KeyboardInterrupt:
        print "Caught KeyboardInterrupt, terminating workers"
        pool.terminate()
        pool.join()

if __name__ == "__main__":
    main()
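However, this code can only be interrupted with Ctrl+C while it is executing time.sleep(5), so ^C works only during the first 5 seconds; once pool.join() is reached, it has no effect at all! The reason is the same as in Solution 1: in Python 2.x, Pool.join() waits for the pool's internal handler threads with Thread.join() and no timeout, which ends up in an untimed Condition.wait().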
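Combining the initializer with Solution 1 should avoid the problem: while work is outstanding, the parent waits in get(timeout) on each AsyncResult instead of in an untimed pool.join(). A Python 3 sketch, assuming the hypothetical names apply_all and square and the same large timeout as before. initializer and pool_factory are parameters only so the function can also run on a thread pool, where signal.signal must not be called from worker threads.

```python
import multiprocessing
import multiprocessing.dummy
import signal


def init_worker():
    # Workers ignore SIGINT; the parent alone handles Ctrl+C.
    signal.signal(signal.SIGINT, signal.SIG_IGN)


def square(x):
    return x * x


def apply_all(func, args_list, processes=4, timeout=99999,
              pool_factory=multiprocessing.Pool, initializer=init_worker):
    pool = pool_factory(processes, initializer)
    try:
        async_results = [pool.apply_async(func, (a,)) for a in args_list]
        # Wait on each AsyncResult with a timeout rather than in
        # pool.join(), so the parent stays interruptible.
        results = [r.get(timeout) for r in async_results]
    except BaseException:
        pool.terminate()
        pool.join()
        raise
    pool.close()
    pool.join()  # all results are in, so this should not block for long
    return results
```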
Recommendations
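First make sure you really need multiple processes. For IO-heavy programs, threads or coroutines are the better choice; reserve multiple processes for compute-heavy work. If you must use multiple processes, consider Python 3's concurrent.futures package (also available for Python 2.x as the futures backport on PyPI), which makes multithreaded and multiprocess code simpler to write; its usage is somewhat similar to Java's concurrent framework.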
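For comparison, a minimal Python 3 sketch of the concurrent.futures API mentioned above; run_all and square are hypothetical names. It only illustrates how compact the code becomes and makes no claim that ProcessPoolExecutor handles Ctrl+C any better than Pool.

```python
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor


def square(x):
    return x * x


def run_all(func, items, executor_cls=ProcessPoolExecutor, max_workers=4):
    """Map func over items with a concurrent.futures executor."""
    # Leaving the with-block shuts the executor down and waits for workers.
    with executor_cls(max_workers=max_workers) as executor:
        # executor.map yields results in input order.
        return list(executor.map(func, items))
```

Switching executor_cls between ThreadPoolExecutor (IO-bound work) and ProcessPoolExecutor (CPU-bound work) follows the advice above without changing the calling code; with processes, call it under an if __name__ == "__main__": guard.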