/* * 版权所有 (c) 上海元泓软件科技有限公司 2022. * 严禁通过任何媒介未经授权复制本文件. * * 作者:mic * Email:funui@outlook.com */ package YHProto import ( "bytes" "encoding/binary" "encoding/json" "errors" "fmt" "github.com/gobwas/ws/wsutil" "github.com/panjf2000/gnet/v2" "github.com/zeromicro/go-zero/core/logx" "hash/crc32" "time" "unsafe" ) //////////////////////////////////////////// // echo command //////////////////////////////////////////// type echoCommand uint32 var echoCommandStrings = [...]string{ "heartbeat", "echo", } func (c echoCommand) String() string { return echoCommandStrings[c] } var ( echoPkgMagic uint32 = 0x20160905 maxEchoStringLen uint32 = 0xffffff // max string length 16MB echoPkgKey = "example key 1234" ) var ( ErrNotEnoughStream = errors.New("packet stream is not enough") ErrTooLargePackage = errors.New("package length is exceed the echo package's legal maximum length") ErrIllegalMagic = errors.New("package magic is not right") ErrTimeOut = errors.New("数据包已过期") ErrCheckSum = errors.New("数据校验和失败") ) var echoPkgHeaderLen int func init() { echoPkgHeaderLen = (int)((uint)(unsafe.Sizeof(EchoPkgHeader{}))) } func SetEchoPkgMagic(magic uint32) { echoPkgMagic = magic } func SetEchoPkgKey(key string) { echoPkgKey = key } type ( // EchoPkgHeader //* 字节位置 4 6 8 12 16 20 //* +----------------+----------+----------+----------------+----------------+----------------+-------------------+ //* | Magic(4) | Seq(2) | Cmd(2) | UT(4) | ChkSum(4) | Len(4) | body | //* +----------------+----------+----------+----------------+----------------+----------------+-------------------+ //* | 消息头部分 | 消息体... | //* +----------------+----------+----------+----------------+----------------+----------------+-------------------+ EchoPkgHeader struct { Magic uint32 //包头,协议版本号 Sequence uint16 // 请求序列 Command uint16 // 命令编号 UT uint32 // 客户端本地时间戳(秒) ChkSum uint32 // 检验和 Len uint32 // 数据包长度 } EchoPackage struct { H EchoPkgHeader B json.RawMessage } ) func (p EchoPkgHeader) String() string { return fmt.Sprintf("请求ID:%d, 命令ID:%s, 数据包长度:%d", p.Sequence, (echoCommand(p.Command)).String(), p.Len) } func (codec *EchoPackage) VerifyTime() bool { // 校验消息时间戳是否过期 return time.Now().Unix()-int64(codec.H.UT) < int64(time.Minute.Seconds()*5) } func (codec *EchoPackage) VerifyBodySize() bool { // 防止恶意客户端把这个字段设置过大导致服务端死等或者服务端在准备对应的缓冲区时内存崩溃 if maxEchoStringLen < codec.H.Len { } return maxEchoStringLen > codec.H.Len } func (codec *EchoPackage) CheckHeader() error { if codec.VerifyBodySize() { return ErrTooLargePackage } if codec.VerifyTime() { return ErrTimeOut } return nil } func (codec *EchoPackage) CalculateCheckSum(body []byte) uint32 { // 计算消息的校验和 buffer := bytes.NewBuffer(nil) binary.Write(buffer, binary.LittleEndian, codec.H.Sequence) binary.Write(buffer, binary.LittleEndian, codec.H.Command) binary.Write(buffer, binary.LittleEndian, codec.H.UT) buffer.Write(body) checkSum := crc32.ChecksumIEEE(buffer.Bytes()) return checkSum } // Encode 编码 func (codec *EchoPackage) Encode() ([]byte, error) { codec.H.Magic = echoPkgMagic codec.H.UT = uint32(time.Now().Unix()) buf := bytes.NewBuffer(nil) err := binary.Write(buf, binary.LittleEndian, codec.H) if err != nil { return nil, err } //加密body encryptedBody, err := Encrypt(codec.B) if err != nil { return nil, err } codec.H.Len = uint32(len(codec.B)) //计算校验和 codec.H.ChkSum = codec.CalculateCheckSum(encryptedBody) buf.Write(encryptedBody) return buf.Bytes(), nil } // Decode 解码buffer func (codec *EchoPackage) DecodeBuffer(msg []byte) (err error) { if len(msg) < echoPkgHeaderLen { return ErrNotEnoughStream } hBuf := msg[:echoPkgHeaderLen] codec.H.Magic = binary.LittleEndian.Uint32(hBuf[0:4]) if codec.H.Magic != echoPkgMagic { logx.Errorf("@p.H.Magic{%x}, right magic{%x}", codec.H.Magic, echoPkgMagic) return ErrIllegalMagic } codec.H.Sequence = binary.LittleEndian.Uint16(hBuf[4:6]) codec.H.Command = binary.LittleEndian.Uint16(hBuf[6:8]) codec.H.UT = binary.LittleEndian.Uint32(hBuf[8:12]) codec.H.ChkSum = binary.LittleEndian.Uint32(hBuf[12:16]) codec.H.Len = binary.LittleEndian.Uint32(hBuf[16:20]) // 防止恶意客户端把这个字段设置过大导致服务端死等或者服务端在准备对应的缓冲区时内存崩溃 err = codec.CheckHeader() if err != nil { return err } body := msg[echoPkgHeaderLen : echoPkgHeaderLen+int(codec.H.Len)] // 校验和验证 checkSum := codec.CalculateCheckSum(body) if checkSum != codec.H.ChkSum { return ErrCheckSum } //解密 codec.B, err = Decrypt(body) if err != nil { return err } return nil } // Decode 解码gnet func (codec *EchoPackage) DecodeGnet(c gnet.Conn) (err error) { if c.InboundBuffered() < echoPkgHeaderLen { return ErrNotEnoughStream } hBuf, _ := c.Peek(echoPkgHeaderLen) codec.H.Magic = binary.LittleEndian.Uint32(hBuf[0:4]) if codec.H.Magic != echoPkgMagic { logx.Errorf("@p.H.Magic{%x}, right magic{%x}", codec.H.Magic, echoPkgMagic) return ErrIllegalMagic } codec.H.Sequence = binary.LittleEndian.Uint16(hBuf[4:6]) codec.H.Command = binary.LittleEndian.Uint16(hBuf[6:8]) codec.H.UT = binary.LittleEndian.Uint32(hBuf[8:12]) codec.H.ChkSum = binary.LittleEndian.Uint32(hBuf[12:16]) codec.H.Len = binary.LittleEndian.Uint32(hBuf[16:20]) // 防止恶意客户端把这个字段设置过大导致服务端死等或者服务端在准备对应的缓冲区时内存崩溃 if maxEchoStringLen < codec.H.Len { return ErrTooLargePackage } msgLen := echoPkgHeaderLen + int(codec.H.Len) if c.InboundBuffered() < msgLen { return ErrNotEnoughStream } _, _ = c.Discard(echoPkgHeaderLen) bBuf, _ := c.Peek(int(codec.H.Len)) codec.B = make([]byte, codec.H.Len) copy(codec.B, bBuf) _, _ = c.Discard(int(codec.H.Len)) return nil } // Decode 解码wsutil func (codec *EchoPackage) DecodeWsMessage(msg wsutil.Message) (err error) { if len(msg.Payload) < echoPkgHeaderLen { return ErrNotEnoughStream } hBuf := msg.Payload[:echoPkgHeaderLen] codec.H.Magic = binary.LittleEndian.Uint32(hBuf[0:4]) if codec.H.Magic != echoPkgMagic { logx.Errorf("@p.H.Magic{%x}, right magic{%x}", codec.H.Magic, echoPkgMagic) return ErrIllegalMagic } codec.H.Sequence = binary.LittleEndian.Uint16(hBuf[4:6]) codec.H.Command = binary.LittleEndian.Uint16(hBuf[6:8]) codec.H.UT = binary.LittleEndian.Uint32(hBuf[8:12]) codec.H.ChkSum = binary.LittleEndian.Uint32(hBuf[12:16]) codec.H.Len = binary.LittleEndian.Uint32(hBuf[16:20]) // 防止恶意客户端把这个字段设置过大导致服务端死等或者服务端在准备对应的缓冲区时内存崩溃 if maxEchoStringLen < codec.H.Len { return ErrTooLargePackage } if len(msg.Payload)-echoPkgHeaderLen < int(codec.H.Len) { return ErrNotEnoughStream } codec.B = make([]byte, codec.H.Len) copy(codec.B, msg.Payload[echoPkgHeaderLen:echoPkgHeaderLen+int(codec.H.Len)]) return nil }