用Golang手寫一個RPC,理解RPC原理


代碼結構

.
├── client.go
├── coder.go
├── coder_test.go
├── rpc_test.go
├── server.go
├── session.go
└── session_test.go

代碼

client.go

package rpc

import (
	"net"
	"reflect"
)

// rpc 客戶端實現

// 抽象客戶端方法
type Client struct {
	conn net.Conn
}

// client構造方法
func NewClient(conn net.Conn) *Client {
	return &Client{conn: conn}
}

// 客戶端調用服務端rpc實現
// client.RpcCall("login", &req)
func (c *Client) RpcCall(name string, fpr interface{}) {
	// 反射獲取函數原型
	fn := reflect.ValueOf(fpr).Elem()
	// 客戶端邏輯的實現
	f := func(args []reflect.Value) (results []reflect.Value) {
		// 從匿名函數中構建請求參數
		inArgs := make([]interface{}, 0, len(args))
		for _, v := range args {
			inArgs = append(inArgs, v.Interface())
		}
		// 組裝rpc data請求數據
		reqData := RpcData{Name: name, Args: inArgs}
		// 進行數據編碼
		reqByteData, err := encode(reqData)
		if err != nil {
			return
		}
		// 創建session 對象
		session := NewSession(c.conn)
		// 客戶端發送數據
		err = session.Write(reqByteData)
		if err != nil {
			return
		}
		// 讀取客戶端數據
		rspByteData, err := session.Read()
		if err != nil {
			return
		}
		// 數據進行解碼
		rspData, err := decode(rspByteData)
		if err != nil {
			return
		}
		// 處理服務端返回的數據結果
		outArgs := make([]reflect.Value, 0, len(rspData.Args))
		for i, v := range rspData.Args {
			// 數據特殊情況處理
			if v == nil {
				// reflect.Zero() 返回某類型的零值的value
				// .Out()返回函數輸出的參數類型
				// 得到具體第幾個位置的參數的零值
				outArgs = append(outArgs, reflect.Zero(fn.Type().Out(i)))
				continue
			}
			outArgs = append(outArgs, reflect.ValueOf(v))
		}

		return outArgs
	}

	// 函數原型到調用的關鍵,需要2個參數
	// 參數1:函數原型,是Type類型
	// 參數2:返回類型是Value類型
	// 簡單理解:參數1是函數原型,參數2是客戶端邏輯
	v := reflect.MakeFunc(fn.Type(), f)
	fn.Set(v)
}

coder.go

package rpc

import (
	"bytes"
	"encoding/gob"
	"fmt"
)

// 對傳輸的數據進行編解碼
// 使用Golang自帶的一個數據結構序列化編碼/解碼工具 gob

// 定義rpc數據交互式數據傳輸格式
type RpcData struct {
	Name string        // 調用方法名
	Args []interface{} // 調用和返回的參數列表
}

// 編碼
func encode(data RpcData) ([]byte, error) {
	// gob進行編碼
	var buf bytes.Buffer
	// 得到字節編碼器
	encoder := gob.NewEncoder(&buf)
	// 進行編碼
	if err := encoder.Encode(data); err != nil {
		fmt.Printf("gob encode failed, err: %v\n", err)
		return nil, err
	}
	return buf.Bytes(), nil
}

// 解碼
func decode(data []byte) (RpcData, error) {
	// 得到字節解碼器
	buf := bytes.NewBuffer(data)
	decoder := gob.NewDecoder(buf)
	// 解碼數據
	var rd RpcData
	if err := decoder.Decode(&rd); err != nil {
		fmt.Printf("gob decode failed, err: %v\n", err)
		return rd, err
	}
	return rd, nil
}

server.go

package rpc

import (
	"net"
	"reflect"
)

// rpc 服務端實現

// 抽象服務端
type Server struct {
	add   string                   // 連接地址
	funcs map[string]reflect.Value // 存儲方法名和方法的對應關系,服務注冊
}

// server 構造方法
func NewServer(addr string) *Server {
	return &Server{add: addr, funcs: make(map[string]reflect.Value)}
}

// 注冊接口
func (s *Server) Register(name string, fc interface{}) {
	if _, ok := s.funcs[name]; ok {
		return
	}
	s.funcs[name] = reflect.ValueOf(fc)
}

func (s *Server) Run() (err error) {
	listener, err := net.Listen("tcp", s.add)
	if err != nil {
		return
	}
	for {
		// 監聽連接
		conn, err := listener.Accept()
		if err != nil {
			conn.Close()
			continue
		}
		// 創建會話
		session := NewSession(conn)
		// 讀取會話請求數據
		reqData, err := session.Read()
		if err != nil {
			conn.Close()
			continue
		}
		// 數據解碼
		rpcReqData, err := decode(reqData)
		// 獲取客戶端要調用的方法
		fc, ok := s.funcs[rpcReqData.Name];
		if !ok {
			conn.Close()
			continue
		}
		// 獲取請求的參數列表
		args := make([]reflect.Value, 0, len(rpcReqData.Args))
		for _, v := range rpcReqData.Args {
			args = append(args, reflect.ValueOf(v))
		}
		// 調用
		callReslut := fc.Call(args)
		// 處理調用返回的數據結果
		rargs := make([]interface{}, 0, len(callReslut))
		for _, rv := range callReslut {
			rargs = append(rargs, rv.Interface())
		}
		// 構建返回的rpc數據
		rpcRspData := RpcData{Name: rpcReqData.Name, Args: rargs}
		// 返回數據進行編碼
		rspData, err := encode(rpcRspData)
		if err != nil {
			conn.Close()
			continue
		}
		err = session.Write(rspData)
		if err != nil {
			conn.Close()
			continue
		}
	}
	return
}

session.go

package rpc

import (
	"encoding/binary"
	"fmt"
	"io"
	"net"
)

// 處理連接會話

// 會話對象結構體
type Session struct {
	conn net.Conn
}

// 傳輸數據存儲方式
// 字節數組, 添加4個字節的頭,用來存儲數據的長度

// 會話構造函數
func NewSession(conn net.Conn) *Session {
	return &Session{conn: conn}
}

// 從連接中讀取數據
func (s *Session) Read() (data []byte, err error) {
	// 讀取數據header數據
	header := make([]byte, 4)
	_, err = s.conn.Read(header)
	if err != nil {
		fmt.Printf("read conn header data failed, err: %v\n", err)
		return
	}
	// 讀取body數據
	hlen := binary.BigEndian.Uint32(header)
	data = make([]byte, hlen)
	_, err = io.ReadFull(s.conn, data)
	if err != nil {
		fmt.Printf("read conn body data failed, err: %v\n", err)
		return
	}
	return
}

// 向連接中寫入數據
func (s *Session) Write(data []byte) (err error) {
	// 創建數據字節切片
	buf := make([]byte, 4+len(data))
	// 向header寫入數據長度
	binary.BigEndian.PutUint32(buf[:4], uint32(len(data)))
	// 寫入body內容
	copy(buf[4:], data)
	// 寫入連接數據
	_, err = s.conn.Write(buf)
	if err != nil {
		fmt.Printf("write conn data failed, err: %v\n", err)
		return
	}
	return
}

coder_test.go

package rpc

import (
	"testing"
)

func TestCoder(t *testing.T) {
	rd := RpcData{
		Name: "login",
		Args: []interface{}{"zhangsan", "zs123"},
	}

	eData, err := encode(rd)
	if err != nil {
		t.Error(err)
		return
	}
	t.Logf("gob 編碼后數據長度: %d\n", len(eData))

	dData, err := decode(eData)
	if err != nil {
		t.Error(err)
		return
	}
	t.Logf("%#v\n", dData)
}

session_test.go

package rpc

import (
	"net"
	"sync"
	"testing"
)

func TestSession(t *testing.T) {
	addr := ":8080"
	test_data := "my is test data"
	var wg sync.WaitGroup
	wg.Add(2)
	// 寫數據
	go func() {
		defer wg.Done()
		listener, err := net.Listen("tcp", addr)
		if err != nil {
			t.Fatal(err)
			return
		}
		conn, _ := listener.Accept()
		s := NewSession(conn)
		data, err := s.Read()
		if err != nil {
			t.Error(err)
			return
		}
		t.Log(string(data))
	}()

	// 讀數據
	go func() {
		defer wg.Done()
		conn, err := net.Dial("tcp", addr)
		if err != nil {
			t.Fatal(err)
			return
		}
		s := NewSession(conn)
		err = s.Write([]byte(test_data))
		if err != nil {
			return
		}
		t.Log("寫入數據成功")
		return
	}()

	wg.Wait()
}

rpc_test.go

package rpc

import (
	"encoding/gob"
	"fmt"
	"net"
	"testing"
)

// rpc 客戶端和服務端測試

// 定義一個服務端結構體
// 定義一個方法
// 通過調用rpc方法查詢用戶的信息

type User struct {
	Name string
	Age  int
}

// 定義查詢用戶的方法
// 通過用戶id查詢用戶數據
func queryUser(id int) (User, error) {
	// 造一些查詢user的假數據
	users := make(map[int]User)
	users[0] = User{"user01", 22}
	users[1] = User{"user02", 23}
	users[2] = User{"user03", 24}
	if u, ok := users[id]; ok {
		return u, nil
	}
	return User{}, fmt.Errorf("%d id not found", id)

}

func TestRpc(t *testing.T) {
	// 給gob注冊類型
	gob.Register(User{})

	addr := ":8080"

	// 創建服務端
	server := NewServer(addr)
	// 注冊服務
	server.Register("queryUser", queryUser)
	// 啟動服務端
	go server.Run()

	// 創建客戶端連接
	conn, err := net.Dial("tcp", addr)
	if err != nil {
		return
	}
	// 創客戶端
	client := NewClient(conn)
	// 定義函數調用原型
	var query func(int) (User, error)
	// 客戶端調用rpc
	client.RpcCall("queryUser", &query)
	// 得到返回結果
	user, err := query(1)
	if err != nil {
		t.Error(err)
		return
	}
	fmt.Printf("%#v\n", user)
}


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM