130行C語言實現個用戶態線程庫(1)


准確的說是除掉頭文件,測試代碼和非關鍵的純算法代碼(只有雙向環形鏈表的ADT),核心代碼只有130行左右,已經是蠅量級的用戶態線程庫了。把這個庫取名為ezthread,意思是,這太easy了,人人都可以讀懂並且實現這個用戶態線程庫。我把該項目放在github上,歡迎來拍磚: https://github.com/Yuandong-Chen/coroutine/tree/old-version(注意,最新的版本已經用了共享棧技術,能夠支持1000K數量級的協程了,讀完這篇博文后可以進一步參考后續的博文:http://www.cnblogs.com/github-Yuandong-Chen/p/6973932.html)。那么下面談談怎么實現這個ezthread。

大家都會雙向環形鏈表(就是頭尾相連的雙向鏈表),我們構造這個ADT結構:

首先是每個節點:

1 typedef struct __pnode pNode;
2 struct __pnode
3 {
4     pNode *next;
5     pNode *prev;
6     Thread_t *data;
7 };

顯然,next指向下一個節點,prev指向上一個節點,data指向該節點數據,那么這個Thread_t是什么類型的數據結構呢?

typedef struct __ez_thread Thread_t;
struct __ez_thread
{
    Regs regs;
    int tid;
    unsigned int stacktop;
    unsigned int stacksize;
    void *stack;
    void *retval;
};

這個結構體包含了線程內部的信息,比如第一項為Regs,記錄的是各個寄存器的取值(我們在下面給出具體的結構),tid就是線程的ID了,stacktop記錄的是線程棧的頂部(和頁對齊的最大地址,每個線程都有自己的運行時的棧,用於構成他們相對獨立的運行時環境),stacksize就是棧的大小了,stack指針指向我們給該線程棧分配的堆的指針(什么?怎么一會棧一會堆的?我們其實用了malloc函數分配出一些堆空間,把這些空間用於線程棧,當線程退出時候,我們再free這些堆),retval就是線程運行完了的返回值(pthread_join里頭拿到的線程返回值就是這個了)。

下面是寄存器結構體:

typedef struct __thread_table_regs Regs;
struct __thread_table_regs
{
    int _ebp;
    int _esp;
    int _eip;
    int _eflags;
};

真是好懂,一看就知道了,這個結構體只能支持X86體系的計算機了。那么還有個問題,為何只有這些寄存器,沒用其他的比如:eax,ebx,edi,esi等等呢?因為我們在轉換狀態函數switch_to里頭當返回時(准確地說是從上次切換的點切換回來時)用了return來切換回線程運行時環境,return會自動幫助我們把這些其他的寄存器的值恢復原狀的(具體我們放到switch_to的時候再詳細說明)。

然后呢,我們定義了一個游標去取這個環形鏈表的值,否則我們怎么讀取這個環形鏈表里頭的數據呢?總得有個東西指向其中某個節點吧。

typedef struct __loopcursor Cursor;
struct __loopcursor
{
    int total;
    pNode *current;
};

這個游標結構體記錄了現在指向的節點地址和這個環形鏈表里頭一共有多少節點。

我們得用兩個這樣的環形鏈表結構體來支持我們的線程庫,為何是倆呢?一個是正在運行的線程,我們把他們串成一個環形鏈表,取名為live(活的),然后用另外一個鏈表把運行結束的線程串成一串,取名為dead(死的)。然后最開始我們就有個線程在運行了,那就是主線程main,我們用pmain節點來記錄主線程:

extern Cursor live;
extern Cursor dead;
extern Thread_t pmain;

好了,剩下的只有在這些結構體上操作的函數了:

void init();
void switch_to(int, ...);
int threadCreat(Thread_t **, void *(*)(void *), void *);
int threadJoin(Thread_t *, void **);

我們開始時調用init,以初始化我們的live,dead和pmain。然后當我們想創造線程時,就threadCreat就可以了,用法和pthread_create基本一模一樣,熟悉posix多線程的人一看就明白了,threadJoin也是仿照pthread_join接口寫的。這里的switch_to就是最關鍵的運行時環境轉換函數了,當線程調用這個函數時候,我們就切換到其他線程上次暫停的點去執行了(這些狀態都保存在我們的Thread_t結構體里,所以我們能夠記錄下切換前的狀態,從而能夠從容地去切換到下一個線程中)。我們沒有用定時器每隔幾微秒去激發switch_to(實現起來也是非常簡單的,但是得添加多個signal_block函數,非常不簡潔),而是讓線程里頭的函數主動調用switch_to來切換線程,這有點類似協程。

好了,現在講具體的實現了。首先是對雙向鏈表的操作函數,這個東西不是我們的重點,懂基礎算法數據結構的人都能實現,具體是雙向環形鏈表的增查刪操作:

 1 void initCursor(Cursor *cur)
 2 {
 3     cur->total = 0;
 4     cur->current = NULL;
 5 }
 6 
 7 Thread_t *findThread(Cursor *cur, int tid)
 8 {
 9     int counter = cur->total;
10     if(counter == 0){
11         return NULL;
12     }
13 
14     int i;
15     pNode *tmp = cur->current;
16     for (int i = 0; i < counter; ++i)
17     {
18         if((tmp->data)->tid == tid){
19             return tmp->data;
20         }
21 
22         tmp = tmp->next;
23     }
24     return NULL;
25 }
26 
27 int appendThread(Cursor *cur, Thread_t *pth)
28 {
29     if(cur->total == 0)
30     {
31         cur->current = (pNode *)malloc(sizeof(pNode));
32         assert(cur->current);
33         (cur->current)->data = pth;
34         (cur->current)->prev = cur->current;
35         (cur->current)->next = cur->current;
36         cur->total++;
37         return 0;
38     }
39     else
40     {    
41         if(cur->total > MAXCOROUTINES)
42         {
43             assert((cur->total == MAXCOROUTINES));
44             return -1;
45         }
46         
47         pNode *tmp = malloc(sizeof(pNode));
48         assert(tmp);
49         tmp->data = pth;
50         tmp->prev = cur->current;
51         tmp->next = (cur->current)->next;
52         ((cur->current)->next)->prev = tmp;
53         (cur->current)->next = tmp;
54         cur->total++;
55         return 0;
56     }
57 }
58 
59 pNode *deleteThread(Cursor *cur, int tid)
60 {
61     int counter = cur->total;
62     int i;
63     pNode *tmp = cur->current;
64     for (int i = 0; i < counter; ++i)
65     {
66         if((tmp->data)->tid == tid){
67             (tmp->prev)->next = tmp->next;
68             (tmp->next)->prev = tmp->prev;
69             if(tmp == cur->current)
70             {
71                 cur->current = cur->current->next;
72             }  
73 
74             cur->total--;
75             assert(cur->total >= 0);
76             return tmp;
77         }
78         tmp = tmp->next;
79     }
80     return NULL;
81 }
雙向鏈表操作函數

拋開這部分純算法代碼,我們只剩下130行代碼了。這還不如某些函數的代碼量大。但是我們就是在這130行代碼里頭實現了switch_to,threadCreat以及threadJoin等等關鍵代碼。

先說下init怎么實現的:

1 void init()
2 {
3     initCursor(&live);
4     initCursor(&dead);
5     appendThread(&live, &pmain);
6 }

其實關鍵點只有一句,那就是第5行的append(&live,&pmain);往live鏈表里頭添加pmain節點,但是我們的pmain還沒初始化呢,里頭stack,regs等等通通都是0,但是沒事呢,因為當我們第一次進入switch_to的時候,switch_to在跳轉前會幫助我們保存當前線程,這時也就是pmain的運行時狀態。

然后我們看看threadCreat怎么實現:

 1 int threadCreat(Thread_t **pth, void *(*start_rtn)(void *), void *arg)
 2 {
 3 
 4     *pth = malloc(sizeof(Thread_t));
 5     (*pth)->stack = malloc(PTHREAD_STACK_MIN);
 6     assert((*pth)->stack);
 7     (*pth)->stacktop = (((int)(*pth)->stack + PTHREAD_STACK_MIN)&(0xfffff000));
 8     (*pth)->stacksize = PTHREAD_STACK_MIN - (((int)(*pth)->stack + PTHREAD_STACK_MIN) - (*pth)->stacktop);
 9     (*pth)->tid = fetchTID();
10     /* set params */
11     void *dest = (*pth)->stacktop - 12;
12     memcpy(dest, pth, 4);
13     dest += 4;
14     memcpy(dest, &start_rtn, 4);
15     dest += 4;
16     memcpy(dest, &arg, 4);
17     (*pth)->regs._eip = &real_entry;
18     (*pth)->regs._esp = (*pth)->stacktop - 16;
19     (*pth)->regs._ebp = 0;
20     appendThread(&live, (*pth));
21 
22     return 0;
23 }

我們在第4行分配了堆空間,然后讓線程棧頂變量stacktop對齊頁,設置stacksize大小(這個其實對我們的線程庫沒有用,因為我們還沒有實現類似stackguard之類的檢查機制),設置tid,這里fetchTID函數如下:

1 int fetchTID()
2 {
3     static int tid;
4     return ++tid;
5 }

接着,我們在threadCreat函數的11-16行代碼中,在棧頂壓入變量pth,start_rtn以及arg(我們用memcpy來操作線程棧空間),這些都是作為real_entry這個函數的參數壓入線程棧的。我們不難發現,其實每個線程的最初入口地址都是real_entry函數(注意到我們在17行把eip設置為real_entry的地址)。最后,我們於17-19行設置寄存器變量,以滿足剛進入該real_entry時的棧的狀態,在live鏈表中添加該線程結構體指針,返回。這一系列操作導致的效果就是,比如我們第一次調用threadCreat函數,當發生switch_to的時候,當然我們先保存當前線程狀態,然后就從主線程main中切換到了real_entry里頭去了,而且對應的參數我們設置好了,就好像我們在主線程里頭直接調用了real_entry一樣。下面看下real_entry做了些什么:

 1 void real_entry(Thread_t *pth, void *(*start_rtn)(void *), void* args)
 2 {
 3     ALIGN();
 4 
 5     pth->retval = (*start_rtn)(args);
 6 
 7     deleteThread(&live, pth->tid);
 8     appendThread(&dead, pth);
 9 
10     switch_to(-1);
11 }

 

 第3行是對齊棧操作,我們先不做說明。接下來就是調用start_rtn函數,並且把args作為參數,返回值賦給線程的retval。當返回時,說明線程已經運行結束,在live鏈表里頭刪除該節點,在dead鏈表里頭添加該節點。在第10行最后調用switch_to(-1),也就是在switch_to里頭直接跳到下一個線程去執行,且不保存當前狀態。

我們再看下threadJoin函數的實現:

 1 int threadJoin(Thread_t *pth, void **rval_ptr)
 2 {
 3 
 4     Thread_t *find1, *find2;
 5     find1 = findThread(&live, pth->tid);
 6     find2 = findThread(&dead, pth->tid);
 7     
 8 
 9     if((find1 == NULL)&&(find2 == NULL)){
10         
11         return -1;
12     }
13 
14     if(find2){
15         if(rval_ptr != NULL)
16             *rval_ptr = find2->retval;
17 
18         pNode *tmp = deleteThread(&dead, pth->tid);
19         free(tmp);
20         free((Stack_t)find2->stack);
21         free(find2);
22         return 0;
23     }
24 
25     while(1)
26     {
27         switch_to(0);
28         if((find2 = findThread(&dead, pth->tid))!= NULL){
29             if(rval_ptr!= NULL)
30                 *rval_ptr = find2->retval;
31 
32             pNode *tmp = deleteThread(&dead, pth->tid);
33             free(tmp);
34             free((Stack_t)find2->stack);
35             free(find2);
36             return 0;
37         }   
38     }
39     return -1;
40 }

threadJoin是用於回收線程資源並得到返回值的。實現大體的思路就是,我們先查找live和dead里頭有沒有這個線程,如果都沒有,說明根本不存在這個線程,如果dead鏈表里頭有,那么我們就得到返回值(15-16行),然后釋放堆空間(19-22行)。如果在live里頭,說明該線程還沒執行結束,我們進入循環,先調用switch_to(0),保存當前線程狀態,然后切換到下一個線程去。當再次回到這個循環時候,我們繼續看看dead里頭有沒有這個線程,有就設置返回值(29-30行),然后釋放資源(32-35行),否則繼續切換並循環。

最后,最關鍵的,我們給出switch_to的實現:

 1 void switch_to(int signo, ...)
 2 {
 3 
 4     va_list ap; 
 5     va_start(ap, signo);
 6 
 7     Regs regs;
 8 
 9     if(signo == -1)
10     {
11         regs = live.current->data->regs;
12         JMP(regs);
13         assert(0);
14     }
15     
16     int _ebp;
17     int _esp;
18     int _eip = &&_REENTERPOINT;
19     int _eflags;
20     /* save current context */
21     SAVE();
22     /* save context in current thread */
23     live.current->data->regs._eip = _eip;
24     live.current->data->regs._esp = _esp;
25     live.current->data->regs._ebp = _ebp;
26     live.current->data->regs._eflags = _eflags;
27 
28     if(va_arg(ap,int) == -1){
29  _REENTERPOINT:
30         assert(va_arg(ap,int) != -1);
31         return;
32     }
33 
34     va_end(ap);
35     regs = live.current->next->data->regs;
36     live.current = live.current->next;
37     JMP(regs);
38     assert(0);
39 }

先看11-13行,我們把自動變量regs的值賦為當前線程的寄存器的結構體,然后跳轉到當前線程(第12行JMP是跳轉語句,13行永遠不會執行)。這里大家有個疑問,從當前線程跳轉到當前線程,那么還不是當前線程么?然后執行assert(0)報錯退出?!其實只有當線程返回時,也就是在real_entry里頭才可能執行switch_to(-1),注意到real_entry最后的幾行代碼,里頭已經把當前線程從live里頭刪除,並添加到dead里了,所以現在live里頭的當前線程其實是下一個線程。然后我們看21-26行,我們保存當前寄存器的值到當前線程中,注意第18行,我們把返回點設置在了_REENTERPOINT這個標簽上,也就是以后如果再次切換到該線程時,我們會在第30行繼續向下執行,很簡單,第30行的有意義的代碼只有return,也就是恢復其他寄存器(eax,edi,esi等等),然后返回到線程繼續執行。我們繼續看34-38行代碼:我們把自動變量regs的值賦值為下一個線程的寄存器,然后live的當前線程指針current也指向了下一個線程,通過37行JMP,我們調到了下一個線程去執行,下個一個線程可能是real_entry處開始執行,也可能是_REENTERPOINT處開始執行。最后再從新說說31行的return到底return到哪里去了,我們看一下測試代碼:

 1 #include "ezthread.h"
 2 #include <stdio.h>
 3 #include <stdlib.h>
 4 
 5 void *sum1tod(void *d)
 6 {
 7     int i, j=0;
 8 
 9     for (i = 0; i <= d; ++i)
10     {
11         j += i;
12         printf("thread %d is grunting... %d\n",live.current->data->tid , i);
13         switch_to(0); // Give up control to next thread
14     }
15     
16     return ((void *)j);
17 }
18 
19 int main(int argc, char const *argv[])
20 {
21     int res = 0;
22     int i;
23     init();
24     Thread_t *tid1, *tid2;
25     int *res1, *res2;
26 
27     threadCreat(&tid1, sum1tod, 10);
28     threadCreat(&tid2, sum1tod, 10);
29 
30     for (i = 0; i <= 5; ++i){
31         res+=i;
32         printf("main is grunting... %d\n", i); 
33         switch_to(0); //Give up control to next thread
34     }
35     threadJoin(tid1, &res1); //Collect and Release the resourse of tid1
36     threadJoin(tid2, &res2); //Collect and Release the resourse of tid2
37     printf("parallel compute: %d = (1+2+3+4+5) + (1+2+...+10)*2\n", (int)res1+(int)res2+(int)res);
38     return 0;
39 }

注意到我們在測試代碼里頭sum1tod里頭調用了switch_to(0),如果這個循環加法(11-13行)還未結束,那么上述的那個_REENTERPOINT里頭的return就會return回這個循環繼續執行,就如在sum1tod里的switch_to(0)函數直接調用return,什么事情也沒干一樣,但是其實我們經過了無數其他線程的執行,但是在sum1tod里頭毫無感覺,簡直好像其他線程不存在一樣(除非我們在這里頭調用threadJoin等待其他線程結束)。

現在我們給出討厭的內嵌匯編:

 1 #define JMP(r)    asm volatile \
 2                 (   \
 3                     "pushl %3\n\t" \
 4                     "popf\n\t" \
 5                     "movl %0, %%esp\n\t" \
 6                     "movl %2, %%ebp\n\t" \
 7                     "jmp *%1\n\t" \
 8                     : \
 9                     : "m"(r._esp),"a"(r._eip),"m"(r._ebp), "m"(r._eflags) \
10                     :  \
11                 )
12 
13 #define SAVE()                  asm volatile \
14                             (  \
15                                    "movl %%esp, %0\n\t" \
16                                 "movl %%ebp, %1\n\t" \
17                                 "pushf\n\t" \
18                                 "movl (%%esp), %%eax\n\t" \
19                                 "movl %%eax, %2\n\t" \
20                                 "popf\n\t" \
21                                 : "=m"(_esp),"=m"(_ebp), "=m"(_eflags) \ 
22                                 : \
23                                 :  \
24                             )
25 
26 #define ALIGN()             asm volatile \
27                             ( \
28                                 "andl $-16, %%esp\n\t" \
29                                 : \
30                                 : \
31                                 :"%esp" \
32                             )
inline asm

第一個就是起到跳轉作用,第二個是保存寄存器到自動變量作用,最后一個是棧對齊作用。為何要棧對齊?因為我們在堆里頭設置了這個棧的空間,這個和普通的棧空間並不完全一樣,我們需要做對齊處理。

到這里我們就幾乎完全明白了這個線程庫的實現,還有一小點就是switch_to里頭的可變參數怎么回事,其實那個是防止編譯器中消除冗余代碼造成我們_REENTERPOINT中的代碼被優化而整個刪除用的。如果我們在_REENTERPOINT前加入goto語句跳到下面執行,然后刪除這個_REENTERPOINT之前的判斷語句,我們會發現,編譯器會把switch_to里頭的第28-32行作為冗余代碼全部刪除。

謝謝你能看到最后,告訴你們一個消息,其實我們的實現是介於longjmp和匯編實現版本之間的某種實現:我們用匯編保存了運行時狀態,但是其中的return又有點類似longjmp中自動恢復寄存器的作用。而且我們的庫比純匯編實現更具可移植性,但比longjmp實現版本又弱了點。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM