python3中線程池


1.在使用多線程處理任務時也不是線程越多越好,由於在切換線程的時候,需要切換上下文環境,依然會造成cpu的大量開銷。為解決這個問題,線程池的概念被提出來了。預先創建好一個較為優化的數量的線程,讓過來的任務立刻能夠使用,就形成了線程池。在python中,沒有內置的較好的線程池模塊,需要自己實現或使用第三方模塊。下面是一個簡單的線程池:

import threading,time,os,queue

class ThreadPool(object):
    def __init__(self,maxsize):
        self.maxsize = maxsize
        self._q = queue.Queue(self.maxsize)
        for i in range(self.maxsize):
            self._q.put(threading.Thread)

    def getThread(self):
        return self._q.get()

    def addThread(self):
        self._q.put(threading.Thread)

def fun(num,p):
    print('this is thread [%s]'%num)
    time.sleep(1)
    p.addThread()


if __name__ == '__main__':
    pool = ThreadPool(2)
    for i in range(103):
        t = pool.getThread()
        a = t(target = fun,args = (i,pool))
        a.start()

 上面的例子是把線程類當做元素添加到隊列內。實現方法比較糙,每個線程使用后就被拋棄,一開始就將線程開到滿,因此性能較差。下面是一個相對好一點的例子,在這個例子中,隊列里存放的不再是線程對象,而是任務對象,線程池也不是一開始就直接開辟所有線程,而是根據需要,逐步建立,直至池滿。通過詳細的代碼注釋,應該會有個清晰的理解。

2.高級版線程池

"""
一個基於thread和queue的線程池,以任務為隊列元素,動態創建線程,重復利用線程,
通過close和terminate方法關閉線程池。
"""
import queue
import threading
import contextlib
import time

# 創建空對象,用於停止線程
StopEvent = object()


def callback(status, result):
    """
    根據需要進行的回調函數,默認不執行。
    :param status: action函數的執行狀態
    :param result: action函數的返回值
    :return:
    """
    pass


def action(thread_name,arg):
    """
    真實的任務定義在這個函數里
    :param thread_name: 執行該方法的線程名
    :param arg: 該函數需要的參數
    :return:
    """
    # 模擬該函數執行了0.1秒
    time.sleep(0.1)
    print("第%s個任務調用了線程 %s,並打印了這條信息!" % (arg+1, thread_name))


class ThreadPool:

    def __init__(self, max_num, max_task_num=None):
        """
        初始化線程池
        :param max_num: 線程池最大線程數量
        :param max_task_num: 任務隊列長度
        """
        # 如果提供了最大任務數的參數,則將隊列的最大元素個數設置為這個值。
        if max_task_num:
            self.q = queue.Queue(max_task_num)
        # 默認隊列可接受無限多個的任務
        else:
            self.q = queue.Queue()
        # 設置線程池最多可實例化的線程數
        self.max_num = max_num
        # 任務取消標識
        self.cancel = False
        # 任務中斷標識
        self.terminal = False
        # 已實例化的線程列表
        self.generate_list = []
        # 處於空閑狀態的線程列表
        self.free_list = []

    def put(self, func, args, callback=None):
        """
        往任務隊列里放入一個任務
        :param func: 任務函數
        :param args: 任務函數所需參數
        :param callback: 任務執行失敗或成功后執行的回調函數,回調函數有兩個參數
        1、任務函數執行狀態;2、任務函數返回值(默認為None,即:不執行回調函數)
        :return: 如果線程池已經終止,則返回True否則None
        """
        # 先判斷標識,看看任務是否取消了
        if self.cancel:
            return
        # 如果沒有空閑的線程,並且已創建的線程的數量小於預定義的最大線程數,則創建新線程。
        if len(self.free_list) == 0 and len(self.generate_list) < self.max_num:
            self.generate_thread()
        # 構造任務參數元組,分別是調用的函數,該函數的參數,回調函數。
        w = (func, args, callback,)
        # 將任務放入隊列
        self.q.put(w)

    def generate_thread(self):
        """
        創建一個線程
        """
        # 每個線程都執行call方法
        t = threading.Thread(target=self.call)
        t.start()

    def call(self):
        """
        循環去獲取任務函數並執行任務函數。在正常情況下,每個線程都保存生存狀態,
        直到獲取線程終止的flag。
        """
        # 獲取當前線程的名字
        current_thread = threading.currentThread().getName()
        # 將當前線程的名字加入已實例化的線程列表中
        self.generate_list.append(current_thread)
        # 從任務隊列中獲取一個任務
        event = self.q.get()
        # 讓獲取的任務不是終止線程的標識對象時
        while event != StopEvent:
            # 解析任務中封裝的三個參數
            func, arguments, callback = event
            # 抓取異常,防止線程因為異常退出
            try:
                # 正常執行任務函數
                result = func(current_thread, *arguments)
                success = True
            except Exception as e:
                # 當任務執行過程中彈出異常
                result = None
                success = False
            # 如果有指定的回調函數
            if callback is not None:
                # 執行回調函數,並抓取異常
                try:
                    callback(success, result)
                except Exception as e:
                    pass
            # 當某個線程正常執行完一個任務時,先執行worker_state方法
            with self.worker_state(self.free_list, current_thread):
                # 如果強制關閉線程的flag開啟,則傳入一個StopEvent元素
                if self.terminal:
                    event = StopEvent
                # 否則獲取一個正常的任務,並回調worker_state方法的yield語句
                else:
                    # 從這里開始又是一個正常的任務循環
                    event = self.q.get()
        else:
            # 一旦發現任務是個終止線程的標識元素,將線程從已創建線程列表中刪除
            self.generate_list.remove(current_thread)

    def close(self):
        """
        執行完所有的任務后,讓所有線程都停止的方法
        """
        # 設置flag
        self.cancel = True
        # 計算已創建線程列表中線程的個數,然后往任務隊列里推送相同數量的終止線程的標識元素
        full_size = len(self.generate_list)
        while full_size:
            self.q.put(StopEvent)
            full_size -= 1

    def terminate(self):
        """
        在任務執行過程中,終止線程,提前退出。
        """
        self.terminal = True
        # 強制性的停止線程
        while self.generate_list:
            self.q.put(StopEvent)

    # 該裝飾器用於上下文管理
    @contextlib.contextmanager
    def worker_state(self, state_list, worker_thread):
        """
        用於記錄空閑的線程,或從空閑列表中取出線程處理任務
        """
        # 將當前線程,添加到空閑線程列表中
        state_list.append(worker_thread)
        # 捕獲異常
        try:
            # 在此等待
            yield
        finally:
            # 將線程從空閑列表中移除
            state_list.remove(worker_thread)

# 調用方式
if __name__ == '__main__':
    # 創建一個最多包含5個線程的線程池
    pool = ThreadPool(5)
    # 創建100個任務,讓線程池進行處理
    for i in range(100):
        pool.put(action, (i,), callback)
    # 等待一定時間,讓線程執行任務
    time.sleep(3)
    print("-" * 50)
    print("\033[32;0m任務停止之前線程池中有%s個線程,空閑的線程有%s個!\033[0m"
          % (len(pool.generate_list), len(pool.free_list)))
    # 正常關閉線程池
    pool.close()
    print("任務執行完畢,正常退出!")
    # 強制關閉線程池
    # pool.terminate()
    # print("強制停止任務!")

 

3.利用簡單線程池和paramiko實現對遠程服務器的訪問獲取到相關信息:(自己寫的例子,比較low)

import paramiko,threading
import queue

class ThreadPool(object):
    def __init__(self,maxsize):
        self.maxsize = maxsize
        self._q = queue.Queue(self.maxsize)
        for i in range(self.maxsize):
            self._q.put(threading.Thread)

    def getThread(self):
        return self._q.get()

    def addThread(self):
        self._q.put(threading.Thread)

def ssh_fun(ip,user,password,pool):
    try:
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(ip, 22, user, password)
        stdin, stdout, stderr = ssh.exec_command('hostname')
        info = stdout.read().decode().strip()
        print('IP:%s  hostname:%s'%(ip,info))
        ssh.close()
    except Exception:
        print('sorry I can`t connect this server [%s]'%ip)
    pool.addThread()

if __name__ == '__main__':
    t_list = []
    pool = ThreadPool(2)
    with open('aaa','r+',encoding='utf-8') as f:
        for line in f:
            split = line.split()
            ip,user,password = split[0],split[1],split[2]
            th = pool.getThread()
            t = th(target=ssh_fun,args=(ip,user,password,pool))
            t.start()
            t_list.append(t)
    for i in t_list:
        i.join()

 在這里我為了測試線程池中只有兩個線程,並且我這個是讀取aaa文件的,這個文件中包含用戶名和密碼等相關信息,樣子如下(其實可以把這些放進數據庫中,使用python從數據庫中進行讀取):

192.168.32.167  root    111111
192.168.32.110  root    111111
192.168.32.120  root    111111
192.168.32.150  root    111111

 而最后執行的效果如下:

IP:192.168.32.167  hostname:ns.root
sorry I can`t connect this server [192.168.32.110]
IP:192.168.32.150  hostname:localhost.localdomain
sorry I can`t connect this server [192.168.32.120]

4.改進版結合了mysql向數據庫中插入數據:

import paramiko,threading
import queue
import pymysql

class ThreadPool(object):
    def __init__(self,maxsize):
        self.maxsize = maxsize
        self._q = queue.Queue(self.maxsize)
        for i in range(self.maxsize):
            self._q.put(threading.Thread)

    def getThread(self):
        return self._q.get()

    def addThread(self):
        self._q.put(threading.Thread)

def ssh_fun(ip,user,password,pool,db):
    cursor = db.cursor()
    try:
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(ip, 22, user, password)
        stdin, stdout, stderr = ssh.exec_command('hostname')
        info = stdout.read().decode().strip()
        print('IP:%s  hostname:%s'%(ip,info))
        try:
            cursor.execute('insert into server_status(ip,password,hostname) values ("%s","%s","%s")' %(ip,password,info))
            db.commit()
        except:
            db.rollback()
        ssh.close()
    except Exception:
        print('sorry I can`t connect this server [%s]'%ip)
    pool.addThread()

if __name__ == '__main__':
    t_list = []
    pool = ThreadPool(3)
    db = pymysql.connect('192.168.32.188', 'hjc', '111111', 'hjc')
    with open('aaa','r+',encoding='utf-8') as f:
        for line in f:
            split = line.split()
            ip,user,password = split[0],split[1],split[2]
            th = pool.getThread()
            t = th(target=ssh_fun,args=(ip,user,password,pool,db))
            t.start()
            t_list.append(t)
    for i in t_list:
        i.join()
    db.close()

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM