代碼結構
.
├── client.go
├── coder.go
├── coder_test.go
├── rpc_test.go
├── server.go
├── session.go
└── session_test.go
代碼
client.go
package rpc
import (
"net"
"reflect"
)
// rpc 客戶端實現
// 抽象客戶端方法
type Client struct {
conn net.Conn
}
// client構造方法
func NewClient(conn net.Conn) *Client {
return &Client{conn: conn}
}
// 客戶端調用服務端rpc實現
// client.RpcCall("login", &req)
func (c *Client) RpcCall(name string, fpr interface{}) {
// 反射獲取函數原型
fn := reflect.ValueOf(fpr).Elem()
// 客戶端邏輯的實現
f := func(args []reflect.Value) (results []reflect.Value) {
// 從匿名函數中構建請求參數
inArgs := make([]interface{}, 0, len(args))
for _, v := range args {
inArgs = append(inArgs, v.Interface())
}
// 組裝rpc data請求數據
reqData := RpcData{Name: name, Args: inArgs}
// 進行數據編碼
reqByteData, err := encode(reqData)
if err != nil {
return
}
// 創建session 對象
session := NewSession(c.conn)
// 客戶端發送數據
err = session.Write(reqByteData)
if err != nil {
return
}
// 讀取客戶端數據
rspByteData, err := session.Read()
if err != nil {
return
}
// 數據進行解碼
rspData, err := decode(rspByteData)
if err != nil {
return
}
// 處理服務端返回的數據結果
outArgs := make([]reflect.Value, 0, len(rspData.Args))
for i, v := range rspData.Args {
// 數據特殊情況處理
if v == nil {
// reflect.Zero() 返回某類型的零值的value
// .Out()返回函數輸出的參數類型
// 得到具體第幾個位置的參數的零值
outArgs = append(outArgs, reflect.Zero(fn.Type().Out(i)))
continue
}
outArgs = append(outArgs, reflect.ValueOf(v))
}
return outArgs
}
// 函數原型到調用的關鍵,需要2個參數
// 參數1:函數原型,是Type類型
// 參數2:返回類型是Value類型
// 簡單理解:參數1是函數原型,參數2是客戶端邏輯
v := reflect.MakeFunc(fn.Type(), f)
fn.Set(v)
}
coder.go
package rpc
import (
"bytes"
"encoding/gob"
"fmt"
)
// 對傳輸的數據進行編解碼
// 使用Golang自帶的一個數據結構序列化編碼/解碼工具 gob
// 定義rpc數據交互式數據傳輸格式
type RpcData struct {
Name string // 調用方法名
Args []interface{} // 調用和返回的參數列表
}
// 編碼
func encode(data RpcData) ([]byte, error) {
// gob進行編碼
var buf bytes.Buffer
// 得到字節編碼器
encoder := gob.NewEncoder(&buf)
// 進行編碼
if err := encoder.Encode(data); err != nil {
fmt.Printf("gob encode failed, err: %v\n", err)
return nil, err
}
return buf.Bytes(), nil
}
// 解碼
func decode(data []byte) (RpcData, error) {
// 得到字節解碼器
buf := bytes.NewBuffer(data)
decoder := gob.NewDecoder(buf)
// 解碼數據
var rd RpcData
if err := decoder.Decode(&rd); err != nil {
fmt.Printf("gob decode failed, err: %v\n", err)
return rd, err
}
return rd, nil
}
server.go
package rpc
import (
"net"
"reflect"
)
// rpc 服務端實現
// 抽象服務端
type Server struct {
add string // 連接地址
funcs map[string]reflect.Value // 存儲方法名和方法的對應關系,服務注冊
}
// server 構造方法
func NewServer(addr string) *Server {
return &Server{add: addr, funcs: make(map[string]reflect.Value)}
}
// 注冊接口
func (s *Server) Register(name string, fc interface{}) {
if _, ok := s.funcs[name]; ok {
return
}
s.funcs[name] = reflect.ValueOf(fc)
}
func (s *Server) Run() (err error) {
listener, err := net.Listen("tcp", s.add)
if err != nil {
return
}
for {
// 監聽連接
conn, err := listener.Accept()
if err != nil {
conn.Close()
continue
}
// 創建會話
session := NewSession(conn)
// 讀取會話請求數據
reqData, err := session.Read()
if err != nil {
conn.Close()
continue
}
// 數據解碼
rpcReqData, err := decode(reqData)
// 獲取客戶端要調用的方法
fc, ok := s.funcs[rpcReqData.Name];
if !ok {
conn.Close()
continue
}
// 獲取請求的參數列表
args := make([]reflect.Value, 0, len(rpcReqData.Args))
for _, v := range rpcReqData.Args {
args = append(args, reflect.ValueOf(v))
}
// 調用
callReslut := fc.Call(args)
// 處理調用返回的數據結果
rargs := make([]interface{}, 0, len(callReslut))
for _, rv := range callReslut {
rargs = append(rargs, rv.Interface())
}
// 構建返回的rpc數據
rpcRspData := RpcData{Name: rpcReqData.Name, Args: rargs}
// 返回數據進行編碼
rspData, err := encode(rpcRspData)
if err != nil {
conn.Close()
continue
}
err = session.Write(rspData)
if err != nil {
conn.Close()
continue
}
}
return
}
session.go
package rpc
import (
"encoding/binary"
"fmt"
"io"
"net"
)
// 處理連接會話
// 會話對象結構體
type Session struct {
conn net.Conn
}
// 傳輸數據存儲方式
// 字節數組, 添加4個字節的頭,用來存儲數據的長度
// 會話構造函數
func NewSession(conn net.Conn) *Session {
return &Session{conn: conn}
}
// 從連接中讀取數據
func (s *Session) Read() (data []byte, err error) {
// 讀取數據header數據
header := make([]byte, 4)
_, err = s.conn.Read(header)
if err != nil {
fmt.Printf("read conn header data failed, err: %v\n", err)
return
}
// 讀取body數據
hlen := binary.BigEndian.Uint32(header)
data = make([]byte, hlen)
_, err = io.ReadFull(s.conn, data)
if err != nil {
fmt.Printf("read conn body data failed, err: %v\n", err)
return
}
return
}
// 向連接中寫入數據
func (s *Session) Write(data []byte) (err error) {
// 創建數據字節切片
buf := make([]byte, 4+len(data))
// 向header寫入數據長度
binary.BigEndian.PutUint32(buf[:4], uint32(len(data)))
// 寫入body內容
copy(buf[4:], data)
// 寫入連接數據
_, err = s.conn.Write(buf)
if err != nil {
fmt.Printf("write conn data failed, err: %v\n", err)
return
}
return
}
coder_test.go
package rpc
import (
"testing"
)
func TestCoder(t *testing.T) {
rd := RpcData{
Name: "login",
Args: []interface{}{"zhangsan", "zs123"},
}
eData, err := encode(rd)
if err != nil {
t.Error(err)
return
}
t.Logf("gob 編碼后數據長度: %d\n", len(eData))
dData, err := decode(eData)
if err != nil {
t.Error(err)
return
}
t.Logf("%#v\n", dData)
}
session_test.go
package rpc
import (
"net"
"sync"
"testing"
)
func TestSession(t *testing.T) {
addr := ":8080"
test_data := "my is test data"
var wg sync.WaitGroup
wg.Add(2)
// 寫數據
go func() {
defer wg.Done()
listener, err := net.Listen("tcp", addr)
if err != nil {
t.Fatal(err)
return
}
conn, _ := listener.Accept()
s := NewSession(conn)
data, err := s.Read()
if err != nil {
t.Error(err)
return
}
t.Log(string(data))
}()
// 讀數據
go func() {
defer wg.Done()
conn, err := net.Dial("tcp", addr)
if err != nil {
t.Fatal(err)
return
}
s := NewSession(conn)
err = s.Write([]byte(test_data))
if err != nil {
return
}
t.Log("寫入數據成功")
return
}()
wg.Wait()
}
rpc_test.go
package rpc
import (
"encoding/gob"
"fmt"
"net"
"testing"
)
// rpc 客戶端和服務端測試
// 定義一個服務端結構體
// 定義一個方法
// 通過調用rpc方法查詢用戶的信息
type User struct {
Name string
Age int
}
// 定義查詢用戶的方法
// 通過用戶id查詢用戶數據
func queryUser(id int) (User, error) {
// 造一些查詢user的假數據
users := make(map[int]User)
users[0] = User{"user01", 22}
users[1] = User{"user02", 23}
users[2] = User{"user03", 24}
if u, ok := users[id]; ok {
return u, nil
}
return User{}, fmt.Errorf("%d id not found", id)
}
func TestRpc(t *testing.T) {
// 給gob注冊類型
gob.Register(User{})
addr := ":8080"
// 創建服務端
server := NewServer(addr)
// 注冊服務
server.Register("queryUser", queryUser)
// 啟動服務端
go server.Run()
// 創建客戶端連接
conn, err := net.Dial("tcp", addr)
if err != nil {
return
}
// 創客戶端
client := NewClient(conn)
// 定義函數調用原型
var query func(int) (User, error)
// 客戶端調用rpc
client.RpcCall("queryUser", &query)
// 得到返回結果
user, err := query(1)
if err != nil {
t.Error(err)
return
}
fmt.Printf("%#v\n", user)
}