gomock是Go官方(github.com/golang/mock)提供的mock框架,可以為代碼中的接口類型生成mock實現,方便編寫單元測試。注意:golang/mock倉庫已停止維護,後續版本由Uber以go.uber.org/mock繼續維護,用法基本相同。
安裝mockgen
go install github.com/golang/mock/mockgen@v1.6.0
構建mock
當我們為數據庫相關的函數編寫單元測試時,不能在測試過程中連接真實的數據庫,這時就需要mock數據庫訪問接口(例如下面的EmployeesRepository),以便進行單元測試。
使用mockgen工具為接口生成相應的mock代碼。執行下面的命令後,當前項目下會生成一個mock文件夾,裡面存放著mock_employees.go文件。
mockgen -source=employees.go -destination=mock/mock_employees.go -package=mock
mock_employees.go文件中的內容就是mock相關接口的代碼了。
我們通常不需要編輯它,只需要在單元測試中按照規定的方式使用它們就可以了。
-source:包含要mock的接口的文件。
-destination:生成的源代碼寫入的文件。如果不設置此項,代碼將打印到標準輸出。
-package:生成的mock源代碼所使用的包名。如果不設置此項,包名默認為在原包名前添加mock_前綴。
-imports:在生成的源代碼中使用的顯式導入列表。值為foo=bar/baz形式的逗號分隔的元素列表,其中bar/baz是要導入的包,foo是要在生成的源代碼中用於包的標識符。
生成的mock文件(無需手動編輯):
// Code generated by MockGen. DO NOT EDIT.
// Source: employees.go
// Package mock is a generated GoMock package.
package mock
import (
reflect "reflect"
domain "server/domain"
gomock "github.com/golang/mock/gomock"
)
// MockEmployeesRepository is a mock of EmployeesRepository interface.
// It is produced by mockgen from employees.go (see the file header);
// regenerate it instead of editing it by hand.
type MockEmployeesRepository struct {
// ctrl is the gomock controller that receives every mocked call and recorded expectation.
ctrl *gomock.Controller
// recorder is what EXPECT() hands out for registering expected calls.
recorder *MockEmployeesRepositoryMockRecorder
}
// MockEmployeesRepositoryMockRecorder is the mock recorder for MockEmployeesRepository.
// Its methods mirror the mocked interface, but they register expectations on the
// controller instead of performing calls.
type MockEmployeesRepositoryMockRecorder struct {
// mock is the owning mock; its ctrl is where expectations are recorded.
mock *MockEmployeesRepository
}
// NewMockEmployeesRepository creates a new mock instance.
// The mock and its recorder share ctrl: calls on the mock go through ctrl.Call,
// and expectations set via EXPECT() go through ctrl.RecordCallWithMethodType.
func NewMockEmployeesRepository(ctrl *gomock.Controller) *MockEmployeesRepository {
mock := &MockEmployeesRepository{ctrl: ctrl}
// The recorder keeps a back-pointer so it can reach ctrl when recording.
mock.recorder = &MockEmployeesRepositoryMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
// Each recorder method returns a *gomock.Call on which the expected
// behaviour (e.g. return values) is configured via gomock's Call API.
func (m *MockEmployeesRepository) EXPECT() *MockEmployeesRepositoryMockRecorder {
return m.recorder
}
// CreateEmployee mocks base method.
// It forwards the call to the gomock controller (ctrl.Call). Per gomock's
// contract, the controller supplies the return values configured through EXPECT().
func (m *MockEmployeesRepository) CreateEmployee(arg0 *domain.Employees) error {
// Mark this frame as a test helper (testing.T.Helper semantics) so failures point at the caller.
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateEmployee", arg0)
// Comma-ok assertion: a nil (or non-error) first value yields a nil error instead of panicking.
ret0, _ := ret[0].(error)
return ret0
}
// CreateEmployee indicates an expected call of CreateEmployee.
// arg0 is typed interface{}, presumably so callers can pass either a concrete
// value or a gomock matcher (gomock convention; confirm against its docs).
// The mocked method's reflect type is passed along with the method name,
// which must match the name used in the mock's ctrl.Call.
func (mr *MockEmployeesRepositoryMockRecorder) CreateEmployee(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateEmployee", reflect.TypeOf((*MockEmployeesRepository)(nil).CreateEmployee), arg0)
}
被測試的函數(EmployeesRepository的實現):
// CreateEmployee inserts employee into the database through GORM.
//
// employee must be non-nil: a nil pointer is rejected up front instead of
// being handed to GORM. Database errors are wrapped with %w, so callers can
// still inspect the underlying cause with errors.Is / errors.As.
func (repo *EmployeesRepository) CreateEmployee(employee *domain.Employees) (err error) {
	if employee == nil {
		return fmt.Errorf("[repository.r.CreateEmployee] failed: employee is nil")
	}
	// Pass the pointer itself, not &employee. With a **domain.Employees whose
	// inner pointer is nil, GORM allocates a zero-value struct and inserts it.
	// NOTE(review): Debug() logs every SQL statement; consider dropping it outside development.
	if err = repo.db.Debug().Create(employee).Error; err != nil {
		return fmt.Errorf("[repository.r.CreateEmployee] failed: employee = %+v, error = %w", employee, err)
	}
	return nil
}
測試用例(這裡使用go-sqlmock模擬底層數據庫連接來測試repository實現,而不是使用上面的gomock):
package impl
import (
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/magiconair/properties/assert"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"server/domain"
)
// getDBMock returns a *gorm.DB backed by a go-sqlmock connection, together
// with the sqlmock handle used to declare the SQL each test expects.
//
// On success the underlying *sql.DB is deliberately left open, because the
// returned *gorm.DB keeps using it after this function returns. If gorm.Open
// fails, the mock connection is closed before the error is returned so it
// does not leak.
func getDBMock() (*gorm.DB, sqlmock.Sqlmock, error) {
	db, mock, err := sqlmock.New()
	if err != nil {
		return nil, nil, err
	}
	gdb, err := gorm.Open(postgres.New(postgres.Config{
		DriverName:           "postgres",
		PreferSimpleProtocol: true,
		Conn:                 db,
	}), &gorm.Config{})
	if err != nil {
		// Best-effort cleanup: the gorm.Open error is the one worth reporting.
		_ = db.Close()
		return nil, nil, err
	}
	return gdb, mock, nil
}
// TestEmployeesRepository_CreateEmployees drives EmployeesRepository.CreateEmployee
// against a sqlmock-backed *gorm.DB.
//
// GORM wraps Create in a transaction by default (SkipDefaultTransaction is not
// set in getDBMock). Every case therefore expects BEGIN before the INSERT,
// followed by COMMIT on success or ROLLBACK on failure. After each case, all
// registered expectations must have been consumed.
func TestEmployeesRepository_CreateEmployees(t *testing.T) {
	type fields struct {
		db *gorm.DB
	}
	type args struct {
		employees *domain.Employees
	}
	db, mock, err := getDBMock()
	assert.Equal(t, err, nil)
	tests := []struct {
		name    string
		fields  fields
		args    args
		invoke  func(args)
		wantErr bool
	}{
		{
			name: "create Employees successful",
			fields: fields{
				db: db,
			},
			args: args{employees: &domain.Employees{
				ID:           0,
				Code:         "1",
				Name:         "zhangsan",
				DepartmentID: 1,
			}},
			invoke: func(args args) {
				mock.ExpectBegin()
				// The trailing AnyArg() values presumably cover non-deterministic
				// columns (e.g. timestamps) of domain.Employees — TODO confirm.
				mock.ExpectQuery("INSERT INTO (.+)").WithArgs(
					args.employees.Code, args.employees.Name, args.employees.DepartmentID,
					sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).
					WillReturnRows(sqlmock.NewRows([]string{"ID"}).AddRow(1))
				mock.ExpectCommit()
			},
		},
		{
			name: "create Employees failed",
			fields: fields{
				db: db,
			},
			args: args{employees: &domain.Employees{
				Code:         "1",
				Name:         "zhangsan",
				DepartmentID: 2,
			}},
			invoke: func(args args) {
				mock.ExpectBegin()
				mock.ExpectQuery("INSERT INTO (.+)").WithArgs(
					args.employees.Code, args.employees.Name, args.employees.DepartmentID,
					sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).
					WillReturnError(gorm.ErrInvalidData)
				// GORM rolls back its default transaction when the statement fails.
				mock.ExpectRollback()
			},
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &EmployeesRepository{
				db: tt.fields.db,
			}
			tt.invoke(tt.args)
			if err := c.CreateEmployee(tt.args.employees); (err != nil) != tt.wantErr {
				t.Errorf("CreateEmployees() error = %v, wantErr %v", err, tt.wantErr)
			}
			// Without this check, a wantErr case would also pass when the error
			// came from unexpected or mismatched SQL rather than the scripted failure.
			if err := mock.ExpectationsWereMet(); err != nil {
				t.Errorf("unfulfilled sqlmock expectations: %v", err)
			}
		})
	}
}