文件 redis/client/client.go
客户端实现：以流水线（pipeline）模式向 Redis 服务器发送请求并接收回复。
package client
// Lifecycle states of a Client, stored in Client.status and accessed with
// sync/atomic. MakeClient leaves status at created (the zero value), Start
// moves it to running and Close moves it to closed.
const (
created = iota
running
closed
)
// B pairs a string channel with a ticker.
// NOTE(review): B is not referenced anywhere in the visible source; confirm it
// is used elsewhere or remove it.
type B struct {
data chan string
ticker *time.Ticker
}
// Client is a pipeline mode redis client.
// Requests flow through two queues: pendingReqs (not yet written) and
// waitingReqs (written, awaiting a reply in send order).
type Client struct {
// conn is the TCP connection to the server; replaced by reconnect.
conn net.Conn
// pendingReqs holds requests waiting to be written by handleWrite.
pendingReqs chan *request // wait to send
// waitingReqs holds written requests, in send order; finishRequest pairs each
// parsed reply with the oldest entry.
waitingReqs chan *request // waiting response
// ticker drives the periodic heartbeat; it is created in Start.
ticker *time.Ticker
// addr is the server address, used by MakeClient and reconnect.
addr string
// status is one of created/running/closed and is accessed atomically.
status int32
working *sync.WaitGroup // its counter presents unfinished requests(pending and waiting)
}
// request is a message sent to the redis server, together with the state
// needed to hand its reply (or failure) back to the waiting caller.
type request struct {
// id is not assigned anywhere in the visible source.
// NOTE(review): confirm whether it is used.
id uint64
// args is the command line; doRequest encodes it as a multi-bulk reply.
args [][]byte
// reply is filled in by finishRequest when the server responds.
reply redis.Reply
// heartbeat marks PING requests issued by doHeartbeat.
heartbeat bool
// waiting is released (Done) when the request completes or fails.
waiting *wait.Wait
// err records a write or connection failure.
err error
}
const (
// chanSize is the buffer capacity of pendingReqs and waitingReqs.
chanSize = 256
// maxWait bounds how long Send and doHeartbeat wait for a request to complete.
maxWait = 3 * time.Second
)
// MakeClient dials addr over TCP and returns a Client that has not been
// started yet; Send refuses requests until Start has been called.
func MakeClient(addr string) (*Client, error) {
	conn, dialErr := net.Dial("tcp", addr)
	if dialErr != nil {
		return nil, dialErr
	}
	c := new(Client)
	c.addr = addr
	c.conn = conn
	c.pendingReqs = make(chan *request, chanSize)
	c.waitingReqs = make(chan *request, chanSize)
	c.working = new(sync.WaitGroup)
	return c, nil
}
// Start launches the writer, reader and heartbeat goroutines (the heartbeat
// ticks every 10 seconds) and then marks the client as running.
func (client *Client) Start() {
	const heartbeatInterval = 10 * time.Second
	client.ticker = time.NewTicker(heartbeatInterval)
	// Each loop consumes its own channel or stream.
	for _, loop := range []func(){client.handleWrite, client.handleRead, client.heartbeat} {
		go loop()
	}
	atomic.StoreInt32(&client.status, running)
}
// Close stops the client. It marks it closed, stops the heartbeat ticker,
// stops accepting new requests and waits for the requests tracked by working.
// It then closes the connection and the waiting queue.
//
// Close is idempotent: only the first call does the work. reconnect calls
// Close when its retries are exhausted, and a later Close from the user must
// not panic by closing already-closed channels.
func (client *Client) Close() {
	if atomic.SwapInt32(&client.status, closed) == closed {
		return
	}
	// ticker is only created by Start; tolerate Close on a never-started client.
	if client.ticker != nil {
		client.ticker.Stop()
	}
	// stop new request
	close(client.pendingReqs)
	// wait for unfinished (pending and waiting) requests
	client.working.Wait()
	// clean
	_ = client.conn.Close()
	close(client.waitingReqs)
}
// reconnect replaces a broken connection. It closes the current connection and
// dials addr up to 3 times, sleeping one second after each failure. If every
// attempt fails, it closes the client. On success it fails every request still
// waiting for a reply on the old connection and restarts handleRead on the new one.
func (client *Client) reconnect() {
	logger.Info("reconnect with: " + client.addr)
	_ = client.conn.Close() // ignore possible errors from repeated closes
	var conn net.Conn
	for i := 0; i < 3; i++ {
		var err error
		conn, err = net.Dial("tcp", client.addr)
		if err == nil {
			break
		}
		logger.Error("reconnect error: " + err.Error())
		time.Sleep(time.Second)
	}
	if conn == nil { // reach max retry, abort
		client.Close()
		return
	}
	client.conn = conn
	// Drain, rather than close and replace, waitingReqs. handleWrite may still
	// be sending into it, and a send on a closed channel panics.
	client.failWaiting(errors.New("connection closed"))
	// restart handle read
	go client.handleRead()
}

// failWaiting completes, with err, every request currently queued in
// waitingReqs. It does not block and stops early if the channel is closed.
func (client *Client) failWaiting(err error) {
	for {
		select {
		case req, ok := <-client.waitingReqs:
			if !ok {
				return
			}
			req.err = err
			req.waiting.Done()
		default:
			return
		}
	}
}
// heartbeat sends a PING on every ticker tick. It returns at the first tick
// it observes after the client has been closed.
// NOTE: time.Ticker.Stop does not close ticker.C, so if no tick is pending
// at Close this goroutine stays blocked. Fully stopping it would need a
// dedicated done channel.
func (client *Client) heartbeat() {
	for range client.ticker.C {
		// A tick buffered before Close must not reach doHeartbeat. It would send
		// on the already-closed pendingReqs channel and panic.
		if atomic.LoadInt32(&client.status) == closed {
			return
		}
		client.doHeartbeat()
	}
}
// handleWrite is the single writer loop. It writes queued requests to the
// connection in FIFO order and returns once pendingReqs is closed.
func (client *Client) handleWrite() {
	for {
		req, ok := <-client.pendingReqs
		if !ok {
			return
		}
		client.doRequest(req)
	}
}
// Send sends a request to redis server and blocks until the reply arrives,
// the request fails or maxWait elapses. Failures (closed client, timeout,
// write/connection error) are returned as error replies.
func (client *Client) Send(args [][]byte) redis.Reply {
	if atomic.LoadInt32(&client.status) != running {
		return protocol.MakeErrReply("client closed")
	}
	req := &request{
		args:      args,
		heartbeat: false,
		waiting:   &wait.Wait{},
	}
	req.waiting.Add(1)
	client.working.Add(1)
	defer client.working.Done()
	// hand the request to handleWrite
	client.pendingReqs <- req
	switch {
	case req.waiting.WaitWithTimeout(maxWait):
		return protocol.MakeErrReply("server time out")
	case req.err != nil:
		return protocol.MakeErrReply("request failed " + req.err.Error())
	default:
		return req.reply
	}
}
// doHeartbeat enqueues a PING request and waits up to maxWait for it to
// complete. The outcome is deliberately ignored.
func (client *Client) doHeartbeat() {
	ping := &request{
		args:      [][]byte{[]byte("PING")},
		heartbeat: true,
		waiting:   &wait.Wait{},
	}
	ping.waiting.Add(1)
	client.working.Add(1)
	defer client.working.Done()
	client.pendingReqs <- ping
	_ = ping.waiting.WaitWithTimeout(maxWait)
}
// doRequest encodes req as a RESP multi-bulk command and writes it to the
// connection. A write that fails with a timeout is retried, up to 3 attempts
// in total, resuming after any bytes already written. On success req is
// queued in waitingReqs so handleRead can complete it with the server's reply.
// On failure req.err is set and the caller is released immediately.
// A nil req is ignored.
func (client *Client) doRequest(req *request) {
	if req == nil {
		return
	}
	if len(req.args) == 0 {
		// Nothing to send. Release the caller now instead of leaving it blocked
		// until maxWait and reporting a misleading server timeout.
		req.err = errors.New("empty command")
		if req.waiting != nil {
			req.waiting.Done()
		}
		return
	}
	payload := protocol.MakeMultiBulkReply(req.args).ToBytes()
	var err error
	written := 0
	for i := 0; i < 3; i++ { // only retry, waiting for handleRead
		var n int
		// A timed-out Write may already have sent a prefix of the command.
		// Resending the whole payload would corrupt the RESP stream.
		n, err = client.conn.Write(payload[written:])
		written += n
		if err == nil ||
			(!strings.Contains(err.Error(), "timeout") && // only retry timeout
				!strings.Contains(err.Error(), "deadline exceeded")) {
			break
		}
	}
	if err == nil {
		// sent successfully: handleRead will deliver the reply
		client.waitingReqs <- req
		return
	}
	req.err = err
	req.waiting.Done()
}
// finishRequest completes the oldest waiting request with reply. Replies are
// matched to requests purely by order. A nil request (as received from a
// closed waitingReqs) is ignored, and a panic is logged instead of propagated.
func (client *Client) finishRequest(reply redis.Reply) {
	defer func() {
		if r := recover(); r != nil {
			debug.PrintStack()
			logger.Error(r)
		}
	}()
	req := <-client.waitingReqs
	if req == nil {
		return
	}
	req.reply = reply
	if w := req.waiting; w != nil {
		w.Done()
	}
}
// handleRead parses replies from the current connection and hands each one to
// finishRequest. On the first error payload it stops reading. If the client
// is closed it just returns; otherwise it calls reconnect, which starts a new
// handleRead.
func (client *Client) handleRead() {
	ch := parser.ParseStream(client.conn)
	for payload := range ch {
		if payload.Err != nil {
			// The parser keeps going after non-fatal protocol errors and only closes
			// ch after a fatal one. Keep receiving in the background so it is not
			// stuck forever on a channel nobody reads.
			go func() {
				for range ch {
				}
			}()
			if atomic.LoadInt32(&client.status) == closed {
				return
			}
			client.reconnect()
			return
		}
		client.finishRequest(payload.Data)
	}
}
文件 redis/conn/conn.go
服务端 TCP 连接（Connection）的封装，管理每个连接的状态：订阅、事务、密码、所选数据库等。
import (
"net"
"sync"
"time"
"github.com/hdt3213/godis/lib/logger"
"github.com/hdt3213/godis/lib/sync/wait"
)
// Bit flags stored in Connection.flags.
const (
// flagSlave means this is a connection with a slave
flagSlave = uint64(1 << iota)
// flagMaster means this is a connection with a master
flagMaster
// flagMulti means this connection is within a transaction
flagMulti
)
// Connection represents a connection with a redis-cli (or any other client)
// together with its per-connection state.
type Connection struct {
// conn is the underlying network connection.
conn net.Conn
// wait until finish sending data, used for graceful shutdown
sendingData wait.Wait
// lock while server sending response
// NOTE(review): in the visible code mu is only taken by Subscribe/UnSubscribe to guard subs.
mu sync.Mutex
// flags is a bit set of flagSlave/flagMaster/flagMulti.
flags uint64
// subscribing channels
subs map[string]bool
// password may be changed by CONFIG command during runtime, so store the password
password string
// queued commands for `multi`
queue [][][]byte
// watching maps watched keys to their version code when watching started.
watching map[string]uint32
// txErrors collects syntax errors recorded via AddTxError during a transaction.
txErrors []error
// selected db
selectedDB int
}
// connPool recycles Connection objects: NewConn takes one from the pool and
// Close puts it back after resetting its state.
var connPool = sync.Pool{
New: func() interface{} {
return &Connection{}
},
}
// RemoteAddr returns the remote network address of the underlying connection
// in its string form.
func (c *Connection) RemoteAddr() string {
	addr := c.conn.RemoteAddr()
	return addr.String()
}
// Close disconnects from the client. It waits up to 10 seconds for in-progress
// writes, closes the underlying connection, resets all per-connection state
// and returns the Connection to connPool for reuse. It always returns nil.
func (c *Connection) Close() error {
	c.sendingData.WaitWithTimeout(10 * time.Second)
	_ = c.conn.Close() // best effort: the peer may already be gone
	c.subs = nil
	c.password = ""
	c.queue = nil
	c.watching = nil
	c.txErrors = nil
	c.selectedDB = 0
	// Reset flags too. Otherwise a recycled Connection would still report itself
	// as a slave/master or as being inside MULTI.
	c.flags = 0
	connPool.Put(c)
	return nil
}
// NewConn creates a Connection for conn, reusing a pooled instance when the
// pool hands back the expected type.
func NewConn(conn net.Conn) *Connection {
	pooled, ok := connPool.Get().(*Connection)
	if ok {
		pooled.conn = conn
		return pooled
	}
	logger.Error("connection pool make wrong type")
	return &Connection{
		conn: conn,
	}
}
// Write sends response to client over tcp connection. An empty b is a no-op.
// While the write is in flight it is counted in sendingData, so Close can
// wait for it to finish.
func (c *Connection) Write(b []byte) (int, error) {
	if len(b) == 0 {
		return 0, nil
	}
	c.sendingData.Add(1)
	defer c.sendingData.Done()
	return c.conn.Write(b)
}
// Name returns the remote address of the connection, or "" when there is no
// underlying connection.
func (c *Connection) Name() string {
	if c.conn == nil {
		return ""
	}
	return c.conn.RemoteAddr().String()
}
// Subscribe adds channel to the set of channels this connection subscribes
// to. The set is allocated on first use; access is guarded by c.mu.
func (c *Connection) Subscribe(channel string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	subs := c.subs
	if subs == nil {
		subs = make(map[string]bool)
		c.subs = subs
	}
	subs[channel] = true
}
// UnSubscribe removes channel from the set of channels this connection
// subscribes to; access is guarded by c.mu.
func (c *Connection) UnSubscribe(channel string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if len(c.subs) > 0 {
		delete(c.subs, channel)
	}
}
// SubsCount returns the number of subscribing channels.
// It holds c.mu, the same lock Subscribe/UnSubscribe take while mutating subs,
// so the read does not race with them.
func (c *Connection) SubsCount() int {
	c.mu.Lock()
	defer c.mu.Unlock()
	return len(c.subs)
}
// GetChannels returns a snapshot of all subscribing channels in unspecified
// (map) order. The result is never nil. subs is read under c.mu, the lock
// that Subscribe/UnSubscribe hold while mutating it.
func (c *Connection) GetChannels() []string {
	c.mu.Lock()
	defer c.mu.Unlock()
	channels := make([]string, 0, len(c.subs))
	for channel := range c.subs {
		channels = append(channels, channel)
	}
	return channels
}
// SetPassword stores password for authentication on this connection. It is
// kept per connection because the server password may be changed at runtime
// by the CONFIG command.
func (c *Connection) SetPassword(password string) {
	c.password = password
}

// GetPassword returns the password stored by SetPassword ("" if none).
func (c *Connection) GetPassword() string {
	pw := c.password
	return pw
}
// InMultiState reports whether the connection is inside an uncommitted
// transaction, i.e. whether flagMulti is set.
func (c *Connection) InMultiState() bool {
	return c.flags&flagMulti != 0
}
// SetMultiState sets or clears the transaction flag. Clearing it also drops
// the transaction's queued commands, watched keys and recorded syntax errors.
func (c *Connection) SetMultiState(state bool) {
	if state {
		c.flags |= flagMulti
		return
	}
	// reset data when cancel multi
	c.watching = nil
	c.queue = nil
	// Nothing else but Close clears txErrors. Without this, errors recorded in
	// one transaction would still be reported for the next one on this connection.
	c.txErrors = nil
	c.flags &^= flagMulti // clean multi flag
}
// GetQueuedCmdLine returns the commands queued in the current transaction.
// The slice returned is the connection's own queue, not a copy.
func (c *Connection) GetQueuedCmdLine() [][][]byte {
	queued := c.queue
	return queued
}

// EnqueueCmd appends cmdLine to the current transaction's queue.
func (c *Connection) EnqueueCmd(cmdLine [][]byte) {
	q := c.queue
	c.queue = append(q, cmdLine)
}

// AddTxError records a syntax error found within the current transaction.
func (c *Connection) AddTxError(err error) {
	errs := c.txErrors
	c.txErrors = append(errs, err)
}

// GetTxErrors returns the syntax errors recorded by AddTxError.
func (c *Connection) GetTxErrors() []error {
	errs := c.txErrors
	return errs
}

// ClearQueuedCmds discards the commands queued in the current transaction.
func (c *Connection) ClearQueuedCmds() {
	c.queue = nil
}
// GetWatching returns the watched keys mapped to the version code recorded
// when watching started. The map is allocated on first access, so the result
// is never nil.
func (c *Connection) GetWatching() map[string]uint32 {
	if c.watching != nil {
		return c.watching
	}
	c.watching = make(map[string]uint32)
	return c.watching
}
// GetDBIndex returns the index of the selected database (0 unless SelectDB
// has been called since the connection was created or recycled).
func (c *Connection) GetDBIndex() int {
	idx := c.selectedDB
	return idx
}

// SelectDB records dbNum as the selected database; no range check is done here.
func (c *Connection) SelectDB(dbNum int) {
	c.selectedDB = dbNum
}
// SetSlave marks the connection as one with a slave.
func (c *Connection) SetSlave() {
	c.flags |= flagSlave
}

// IsSlave reports whether the slave flag is set on the connection.
func (c *Connection) IsSlave() bool {
	return c.flags&flagSlave != 0
}

// SetMaster marks the connection as one with a master.
func (c *Connection) SetMaster() {
	c.flags |= flagMaster
}

// IsMaster reports whether the master flag is set on the connection.
func (c *Connection) IsMaster() bool {
	return c.flags&flagMaster != 0
}
文件 redis/conn/fake.go
伪造的连接（fake connection），用于测试。
文件 redis/parser/parser.go
解析 RESP 协议数据流：既用于解析客户端发来的命令，也用于解析服务器返回的回复。
package parser
import (
"bufio"
"bytes"
"errors"
"io"
"runtime/debug"
"strconv"
"strings"
"github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/logger"
"github.com/hdt3213/godis/redis/protocol"
)
// Payload stores redis.Reply or error.
// parse0 sets exactly one of the two fields on every payload it sends.
type Payload struct {
Data redis.Reply
Err error
}
// ParseStream parses RESP data from reader in a background goroutine and
// returns an unbuffered channel that delivers one Payload per reply or error.
// The channel is closed after a fatal error, such as a read error.
func ParseStream(reader io.Reader) <-chan *Payload {
	out := make(chan *Payload)
	go parse0(reader, out)
	return out
}
// ParseBytes parses every reply contained in data and returns them in order.
// io.EOF marks the end of data and is not reported as an error. Any other
// error payload aborts parsing and is returned.
func ParseBytes(data []byte) ([]redis.Reply, error) {
	ch := make(chan *Payload)
	reader := bytes.NewReader(data)
	go parse0(reader, ch)
	defer func() {
		// parse0 only closes ch after its final send. After an early return it
		// would otherwise block forever, so keep receiving in the background.
		go func() {
			for range ch {
			}
		}()
	}()
	var results []redis.Reply
	for payload := range ch {
		if payload == nil {
			return nil, errors.New("no protocol")
		}
		if payload.Err != nil {
			if payload.Err == io.EOF {
				break
			}
			return nil, payload.Err
		}
		results = append(results, payload.Data)
	}
	return results, nil
}
// ParseOne parses data and returns only the first payload's reply and error.
// If no payload is produced before the channel closes, it returns a
// "no protocol" error.
func ParseOne(data []byte) (redis.Reply, error) {
	ch := make(chan *Payload)
	reader := bytes.NewReader(data)
	go parse0(reader, ch)
	payload := <-ch
	// parse0 keeps producing payloads (at least a trailing io.EOF) and closes ch
	// only after its final send. Drain the rest so that goroutine can exit.
	go func() {
		for range ch {
		}
	}()
	if payload == nil {
		return nil, errors.New("no protocol")
	}
	return payload.Data, payload.Err
}
// parse0 reads RESP data from rawReader and sends one Payload per parsed
// reply to ch. It stops after a fatal error (a read error, including io.EOF,
// or a failure returned by a bulk/array/RDB sub-parser). That error is sent
// as the last payload and ch is closed. Non-fatal protocol errors are sent as
// error payloads and parsing continues. Lines of at most 2 bytes, or lines not
// ending in CRLF, are skipped.
func parse0(rawReader io.Reader, ch chan<- *Payload) {
	defer func() {
		if err := recover(); err != nil {
			logger.Error(err, string(debug.Stack()))
			// Every normal exit closes ch immediately before returning, so ch is
			// still open here. Report the failure and close it; otherwise consumers
			// ranging over ch would block forever.
			ch <- &Payload{Err: errors.New("protocol error: parser panic")}
			close(ch)
		}
	}()
	reader := bufio.NewReader(rawReader)
	for {
		line, err := reader.ReadBytes('\n')
		if err != nil {
			ch <- &Payload{Err: err}
			close(ch)
			return
		}
		length := len(line)
		if length <= 2 || line[length-2] != '\r' {
			// there are some empty lines within replication traffic, ignore this error
			//protocolError(ch, "empty line")
			continue
		}
		line = bytes.TrimSuffix(line, []byte{'\r', '\n'})
		// dispatch on the RESP type byte
		switch line[0] {
		case '+':
			content := string(line[1:])
			ch <- &Payload{
				Data: protocol.MakeStatusReply(content),
			}
			// After FULLRESYNC the RDB arrives as a bulk string without a trailing
			// CRLF, so it needs its own parser.
			if strings.HasPrefix(content, "FULLRESYNC") {
				err = parseRDBBulkString(reader, ch)
				if err != nil {
					ch <- &Payload{Err: err}
					close(ch)
					return
				}
			}
		case '-':
			ch <- &Payload{
				Data: protocol.MakeErrReply(string(line[1:])),
			}
		case ':':
			value, err := strconv.ParseInt(string(line[1:]), 10, 64)
			if err != nil {
				protocolError(ch, "illegal number "+string(line[1:]))
				continue
			}
			ch <- &Payload{
				Data: protocol.MakeIntReply(value),
			}
		case '$':
			err = parseBulkString(line, reader, ch)
			if err != nil {
				ch <- &Payload{Err: err}
				close(ch)
				return
			}
		case '*':
			err = parseArray(line, reader, ch)
			if err != nil {
				ch <- &Payload{Err: err}
				close(ch)
				return
			}
		default:
			// inline command: space-separated arguments
			args := bytes.Split(line, []byte{' '})
			ch <- &Payload{
				Data: protocol.MakeMultiBulkReply(args),
			}
		}
	}
}
// parseBulkString parses a bulk string. header is its header line with CRLF
// already stripped (e.g. "$5"); the body and its trailing CRLF are read from
// reader. "$-1" yields a null bulk reply. An unparsable header is reported
// as a non-fatal protocol error. An oversized length or a read error is
// returned as a fatal error.
func parseBulkString(header []byte, reader *bufio.Reader, ch chan<- *Payload) error {
	// maxBulkLen mirrors Redis' default proto-max-bulk-len (512MB).
	const maxBulkLen = 512 * 1024 * 1024
	strLen, err := strconv.ParseInt(string(header[1:]), 10, 64)
	if err != nil || strLen < -1 {
		protocolError(ch, "illegal bulk string header: "+string(header))
		return nil
	}
	if strLen == -1 {
		ch <- &Payload{
			Data: protocol.MakeNullBulkReply(),
		}
		return nil
	}
	if strLen > maxBulkLen {
		// The length comes from the peer. Reject it before allocating: strLen+2
		// could overflow, or the allocation could exhaust memory (unrecoverable).
		return errors.New("protocol error: bulk string too long: " + string(header))
	}
	body := make([]byte, strLen+2) // payload + CRLF
	if _, err = io.ReadFull(reader, body); err != nil {
		return err
	}
	ch <- &Payload{
		Data: protocol.MakeBulkReply(body[:len(body)-2]),
	}
	return nil
}
// parseRDBBulkString reads the RDB payload that follows FULLRESYNC. It is
// framed as "$<len>\r\n" followed by exactly <len> bytes. There is no CRLF
// between the RDB and the following AOF, so it must be treated differently
// from an ordinary bulk string. The payload is sent to ch as a bulk reply.
// Read errors and malformed headers are returned.
func parseRDBBulkString(reader *bufio.Reader, ch chan<- *Payload) error {
	header, err := reader.ReadBytes('\n')
	if err != nil {
		return err
	}
	header = bytes.TrimSuffix(header, []byte{'\r', '\n'})
	if len(header) == 0 {
		return errors.New("empty header")
	}
	if header[0] != '$' {
		return errors.New("illegal bulk header: " + string(header))
	}
	strLen, err := strconv.ParseInt(string(header[1:]), 10, 64)
	if err != nil || strLen <= 0 {
		return errors.New("illegal bulk header: " + string(header))
	}
	body := make([]byte, strLen)
	if _, err = io.ReadFull(reader, body); err != nil {
		return err
	}
	ch <- &Payload{
		Data: protocol.MakeBulkReply(body),
	}
	return nil
}
// parseArray parses a RESP array of bulk strings (a command line) and sends
// it to ch as a multi-bulk reply. header is the array header with CRLF
// stripped (e.g. "*3"). "*0" yields an empty multi-bulk reply, and null
// elements ("$-1") become empty byte slices. Malformed headers are reported
// as non-fatal protocol errors and the partially read command is discarded.
// Read errors and oversized element lengths are returned.
func parseArray(header []byte, reader *bufio.Reader, ch chan<- *Payload) error {
	// maxBulkLen mirrors Redis' default proto-max-bulk-len (512MB).
	const maxBulkLen = 512 * 1024 * 1024
	nStrs, err := strconv.ParseInt(string(header[1:]), 10, 64)
	if err != nil || nStrs < 0 {
		protocolError(ch, "illegal array header "+string(header[1:]))
		return nil
	}
	if nStrs == 0 {
		ch <- &Payload{
			Data: protocol.MakeEmptyMultiBulkReply(),
		}
		return nil
	}
	// nStrs comes from the peer. Use it only as a bounded capacity hint so the
	// header alone cannot trigger a huge allocation.
	capHint := nStrs
	if capHint > 1024 {
		capHint = 1024
	}
	lines := make([][]byte, 0, capHint)
	for i := int64(0); i < nStrs; i++ {
		line, err := reader.ReadBytes('\n')
		if err != nil {
			return err
		}
		length := len(line)
		if length < 4 || line[length-2] != '\r' || line[0] != '$' {
			protocolError(ch, "illegal bulk string header "+string(line))
			// Drop the incomplete command instead of emitting it after the error.
			return nil
		}
		strLen, err := strconv.ParseInt(string(line[1:length-2]), 10, 64)
		if err != nil || strLen < -1 {
			protocolError(ch, "illegal bulk string length "+string(line))
			return nil
		}
		if strLen == -1 {
			lines = append(lines, []byte{})
			continue
		}
		if strLen > maxBulkLen {
			return errors.New("protocol error: bulk string too long: " + string(line[:length-2]))
		}
		body := make([]byte, strLen+2) // payload + CRLF
		if _, err := io.ReadFull(reader, body); err != nil {
			return err
		}
		lines = append(lines, body[:len(body)-2])
	}
	ch <- &Payload{
		Data: protocol.MakeMultiBulkReply(lines),
	}
	return nil
}
// protocolError sends a non-fatal error payload "protocol error: <msg>" to ch.
func protocolError(ch chan<- *Payload, msg string) {
	ch <- &Payload{Err: errors.New("protocol error: " + msg)}
}
文件 redis/protocol/asserts/asserts.go
测试中使用的断言辅助函数，用于检查回复是否符合预期。
文件 redis/protocol/consts.go
定义协议相关的常量。
文件 redis/protocol/errors.go
定义协议相关的错误回复。
文件 redis/protocol/reply.go
定义协议的回复（Reply）消息类型。
文件 redis/server/server.go
TCP 服务器：接受连接后为其异步启动处理流程，负责处理客户端发来的消息。