GoLang 单元测试打桩和 mock

时间:2025-02-26 20:21:25

目录

什么是 mock

变量打桩

接口方法/Redis

函数/方法打桩

包函数

成员方法

MySQL

sqlmock

sqlite mock gorm

http mock


源码地址

单测基础

什么是 mock

       单元测试,顾名思义对某个单元函数进行测试,被测函数本身中用到的变量、函数、资源不应被测试代码依赖,所谓 mock,就是想办法通过 “虚拟” 代码替换掉依赖的方法和资源,一般需要 mock 掉以下依赖:

  • 变量

  • 函数/方法

  • MySQL

  • Redis

  • http 调用

变量打桩

有时我们的代码里依赖一个全局变量,测试方法根据全局变量的不同值执行不同的逻辑,那么可以用 gostub 对变量进行打桩。

 :

package main

var size = 5

func Size() int {
	if size > 10 {
		return 10
	}
	return size
}
package main

import (
    "testing"

    "/agiledragon/gomonkey/v2"
    "/prashantv/gostub"
)

func TestSizeStub(t *) {
    tests := []struct {
        name string
        want int
        f    func() *
    }{
        {name: "size > 10", want: 10, f: func() * {
            return (&size, 11)
        }},
        {name: "size <= 10", want: 3, f: func() * {
            return (&size, 3)
        }},
    }
    for _, tt := range tests {
        (, func(t *) {
            stub := ()
            if got := Size(); got !=  {
                ("Size() = %v, want %v", got, )
            }
            ()
        })
    }
}

func TestSizeMonkey(t *) {
    tests := []struct {
        name string
        want int
        f    func() *
    }{
        {name: "size > 10", want: 10, f: func() * {
            return (&size, 11)
        }},
        {name: "size <= 10", want: 3, f: func() * {
            return (&size, 3)
        }},
    }
    for _, tt := range tests {
        (, func(t *) {
            stub := ()
            if got := Size(); got !=  {
                ("Size() = %v, want %v", got, )
            }
            ()
        })
    }
}
$ go test -v -cover
=== RUN   TestSize
=== RUN   TestSize/size_>_10
=== RUN   TestSize/size_<=_10
--- PASS: TestSize (0.00s)
    --- PASS: TestSize/size_>_10 (0.00s)
    --- PASS: TestSize/size_<=_10 (0.00s)
PASS
coverage: 100.0% of statements

接口方法/Redis

首先 Go 语言推荐的是面向接口编程,所以官方提供并推荐使用  gomock  对依赖的方法进行 mock,前提是依赖的方法是通过抽象接口实现的,gomock 执行过程如下:

  1. 使用mockgen为你想要mock的接口生成一个mock。

  2. 在你的测试代码中,创建一个实例并把它作为参数传递给mock对象的构造函数来创建一个mock对象。

  3. 调用EXPECT()为你的mock对象设置各种期望和返回值。

  4. 调用mock控制器的Finish()以验证mock的期望行为。

gomock 常用方法:

类型

用法

作用

参数

(v)

匹配任何类型

(v)

匹配使用反射  与 v 相等的值

(v)

v 不是 Matcher 时,匹配使用反射  与 v 不相等的值;v 是 Matcher 时,匹配和 Macher 不匹配的值(Matcher

()

匹配等于 nil 的值

返回

Return()

mock 方法返回值

Do(func)

传入的 func 在 mock 真正被调用时自动执行,忽略 Return,比如:对调用方法的参数进行校验

DoAndReturn(func)

传入的 func 在 mock 真正被调用时自动执行,对应 func 返回值作为 mock 方法返回值

调用次数

AnyTimes(n int)

mock 方法可以被调用任意次数,一次不调用也不会失败(这里大家可以自检一下各自的单测代码,用这个方法的单测可能并没有按照预期运行

Times()

mock 方法被调用次数,次数不相等运行失败

MaxTimes(n int)

mock 方法被调用次数,大于规定次数运行失败

MinTimes(n int)

mock 方法被调用次数,小于规定次数运行失败

调用排序

(

().Return(),

().Return(),

().Return(),

)

规定多个 mock 方法的调用顺序,顺序不符运行失败

first := ().DoFucn()

second := ().DoFunc().After(first)

规定多个 mock 方法的先后依赖关系,顺序不符运行失败

首先通过 mockgen 生成 Redis Client 的 mock 代码:

$ go get -u /golang/mock/gomock
$ go install /golang/mock/mockgen

本地interface:
mockgen[go run -mod=mod /golang/mock/mockgen] -package mock -destination=mock/service/ -source=service/user/ -build_flags=-mod=mod  -mock_names=Service=MockUserService
远端interface:
go run -mod=mod /golang/mock/mockgen -package redis -destination ./mock/redis/  /go-redis/redis/v8 Cmdable

package main

import (
	"context"

	"/go-redis/redis/v8"
)

func handleRedis(c ) (string, error) {
	return ((), "redis").Result()
}

func conn() * {
	return (&{Addr: "127.0.0.1:6379"})
}
package main

import (
	"testing"

	"/go-redis/redis/v8"
	"/golang/mock/gomock"
)

func Test_handleRedis(t *) {
	ctl := (t)
	defer ()

	c := NewMockCmdable(ctl)
	().Get((), ()).Times(1).Return(("redis", nil))

	handleRedis(c)
}

函数/方法打桩

假如我们依赖的其他人写的方法,并不是通过接口实现的,无法使用 gomock 时,可以用 gomonkey 进行打桩

包函数

常用函数:

  • ():单个包函数打桩

  • ():连续多个包函数打桩

package main

func A() int {
	return B()
}

func AA() int {
	return B() + B()
}

func B() int {
	return 0
}
package main

import (
	"testing"

	"/agiledragon/gomonkey/v2"
	"/stretchr/testify/assert"
)

// TestA 函数,单次打桩
func TestA(t *) {
	patch := (B, func() int {
		return 1
	})
	defer ()

	(t, 1, A())
}

// TestAA 函数,连续打桩
func TestAA(t *) {
	patch := (B, []{
		{Values: {1}},
		{Values: {2}},
	})
	defer ()

	(t, 3, AA())
}

成员方法

常用函数:

  • ():单个公有成员方法打桩

  • ():单个私有成员方法打桩

  • ():连续多个公有成员方法打桩

  • ():连续多个私有成员方法打桩

package main

type S struct{}

func (s *S) A() int {
	return () + ()
}

func (s *S) AA() int {
	return () + () + () + ()
}

func (s *S) B() int {
	return 0
}

func (s *S) b() int {
	return 0
}
package main

import (
	"reflect"
	"testing"

	"/agiledragon/gomonkey/v2"
	"/stretchr/testify/assert"
)

// TestS_AA 成员方法单个打桩
func TestS_A(t *) {
	s := &S{}

	// 公共成员方法
	patch := ((s), "B", func(_ *S) int {
		return 1
	})
	// 私有成员方法
	((s), "b", func(_ *S) int {
		return 2
	})
	defer ()

	(t, 3, ())
}

// TestS_AA 成员方法连续打桩
func TestS_AA(t *) {
	s := &S{}

	// 私有成员方法
	patch := ((*S).b, []{
		{Values: {1}},
		{Values: {2}},
	})
	// 公共成员方法
	((s), "B", []{
		{Values: {1}},
		{Values: {2}},
	})
	defer ()

	(t, 6, ())
}

MySQL

sqlmock

package main

import (
	"database/sql"
	"encoding/json"
	"fmt"

	_ "/go-sql-driver/mysql"
	"/jmoiron/sqlx"
)

const dsn = "root:123456@tcp(127.0.0.1:3306)/test"

type Test struct {
	ID      int64  `json:"id" db:"id" gorm:"column:id"`
	GoodsID int64  `json:"goodsID" db:"goods_id" gorm:"column:goods_id"`
	Name    string `json:"name" db:"name" gorm:"column:name"`
}

func (Test) TableName() string {
	return "test"
}

func handle(db *) (err error) {
	tx, err := ()
	if err != nil {
		return
	}

	defer func() {
		switch err {
		case nil:
			err = ()
		default:
			()
		}
	}()

	rows, err := ("SELECT * from test where id > ?", 0)
	if err != nil {
		panic(err)
	}
	result := []Test{}
	if err = (rows, &result); err != nil {
		panic(err)
	}

	b, err := (result)
	if err != nil {
		panic(err)
	}
	("sql:", string(b))

	if _, err = ("UPDATE test SET goods_id = goods_id + 1 where id = 2"); err != nil {
		return
	}
	if _, err = ("INSERT INTO test (goods_id, name) VALUES (?, ?)", 1, "1"); err != nil {
		return
	}
	return
}

func main() {
	db, err := ("mysql", dsn)
	if err != nil {
		panic(err)
	}
	defer ()

	if err = handle(db); err != nil {
		panic(err)
	}
}
package main

import (
	"log"
	"os"
	"testing"
	"time"

	"/DATA-DOG/go-sqlmock"
	_ "/go-sql-driver/mysql"
	"/stretchr/testify/assert"
)

func Test_handle(t *) {
	db, mock, err := ()
	if err != nil {
		panic(err)
	}

	()
	// (.+) 用于替代字段,可用于 select、order、group等
	("SELECT (.+) from test where id > ?").WillReturnRows(([]string{"id", "goods_id", "name"}).AddRow(1, 1, "1"))
    // sql前缀匹配
	("UPDATE test SET goods_id").WillReturnResult((1, 1))
	("INSERT INTO test").WithArgs(1, "1").WillReturnResult((1, 1))
	()

	if err = handle(db); err != nil {
		panic(err)
	}

	if err = (); err != nil {
		panic(err)
	}
}

sqlite mock gorm

如果遇到如下错误:

/usr/local/go16/pkg/tool/linux_amd64/link: running gcc failed: exit status 1
/usr/bin/ld: /tmp/go-link-866330658/(.text+0x74): unresolvable H��@�>H��FH��H��H��@�~�F�H��@�~H��8�H��H��0�FH��H��(�FH��H�� �FH��H���FH��H���FH��H��F�fD relocation against symbol `stderr@@GLIBC_2.2.5'
/usr/bin/ld: BFD version 2.20.51.0.2-5.34.el6 20100205 internal error, aborting at  line 443 in bfd_get_reloc_size
/usr/bin/ld: Please report this bug.
collect2: ld returned 1 exit status

更新 go env gcc 版本:
go env -w CC=/opt/compiler/gcc-8.2/bin/gcc
go env -w CXX=/opt/compiler/gcc-8.2/bin/g++

或

CC=/opt/compiler/gcc-8.2/bin/gcc CXX=/opt/compiler/gcc-8.2/bin/g++ go test -c -cover 

package main

import (
	"database/sql"
	"encoding/json"
	"fmt"

	"/driver/mysql"
	"/gorm"
)

const dsn = "root:123456@tcp(127.0.0.1:3306)/test"

type Test struct {
	ID      int64  `json:"id" db:"id" gorm:"column:id"`
	GoodsID int64  `json:"goodsID" db:"goods_id" gorm:"column:goods_id"`
	Name    string `json:"name" db:"name" gorm:"column:name"`
}

func (Test) TableName() string {
	return "test"
}

func main() {
	orm, err := ((dsn))
	if err != nil {
		panic(err)
	}

	handleOrm(orm)
}

func handleOrm(orm *) {
	var rows []Test

	clause := func(db *) * {
		return ("id >= ?", 1)
	}
	err := clause(("*")).Find(&rows).Error
	if err != nil {
		panic(err)
	}

	b, err := (rows)
	if err != nil {
		panic(err)
	}
	("gorm", string(b))
}
package main

import (
	"log"
	"os"
	"testing"
	"time"

	"/stretchr/testify/assert"
	"/driver/sqlite"
	"/gorm"
	"/gorm/logger"
)

func Test_handleOrm(t *) {
	db := NewMemoryDB()
	err := ().CreateTable(&Test{})
	(t, err)

	handleOrm(db)
}

func NewMemoryDB() * {
	var db *
	var err error
	newLogger := (
		(, "\r\n", ), // io writer
		{
			SlowThreshold: , // 慢 SQL 阈值
			LogLevel:      , // Log level
			Colorful:      false,       // 禁用彩色打印
		},
	)
	dialector := (":memory:?cache=shared")
	if db, err = (dialector, &{
		Logger: newLogger,
	}); err != nil {
		panic(err)
	}
	dba, err := ()
	(1)
	return db
}

func CloseMemoryDB(db *) {
	sqlDB, _ := ()
	()
}

http mock

package main

import (
	"fmt"
	"net/http"
	"time"
)

func Send() (err error) {
	req, err := (, "https://127.0.0.1:8080", nil)
	if err != nil {
		return
	}
	client := &{
		Timeout: ,
	}
	resp, err := (req)
	if err != nil {
		return
	}
	defer ()

	if  !=  {
		return ("HTTP status is %d", )
	}

	return
}
package main

import (
	"net/http"
	"testing"

	"/jarcoal/httpmock"
	"/smartystreets/goconvey/convey"
	"/stretchr/testify/assert"
)

func TestSend(t *) {
	("TestSend", t, func() {
		("success", func() {
			()
			defer ()
			(, "https://127.0.0.1:8080", (, ""))

			err := Send()
			(t, err)
		})
	})
}