源码分享-golang的二进制文件读写库
库功能
功能类似 Go 标准库 `encoding/binary`，用于二进制码流/文件的读写。与标准库相比，本库在以下方面做了功能增强：
- 支持bit级别的结构体成员编解码
- 支持 `bit`、`sort` 结构体标签，用于指定结构体成员的 bit 数和大小端属性
- 支持 `map` 类型
- 支持 `string` 字符串
- 支持自定义编解码方法
库源码
decode.go
package binary
import (
"fmt"
"io"
"math"
"reflect"
)
// Unmarshaler is implemented by types that know how to decode themselves
// from a Decoder rather than being decoded field by field via reflection
// (presumably dispatched through handleMethods — confirm).
//
// bigEndian selects the byte order requested by the caller, and bits is the
// bit width requested for the value (0 when no explicit width was given).
type Unmarshaler interface {
	UnmarshalBinary(d *Decoder, bigEndian bool, bits int) error
}
// Unmarshal decodes buf into each value of e, in order, using the byte order
// selected by isBig. It is shorthand for building a fresh Decoder at absolute
// offset 0 and calling its Unmarshal method; the first error stops decoding.
func Unmarshal(buf []byte, isBig bool, e ...any) error {
	dec := NewDecoder(buf, 0)
	return dec.Unmarshal(isBig, e...)
}
// EOF is this package's end-of-input value: Read and ReadByte return it, and
// ReadBytes panics with it, once the buffer has no bytes left.
// NOTE(review): newErr is not visible here; unless it wraps io.EOF, io helpers
// that stop on io.EOF (e.g. io.ReadAll over a Decoder) will report this as a
// failure — confirm before relying on errors.Is(err, io.EOF).
var EOF = newErr("read the end")
// Decoder reads values from an in-memory byte buffer with bit-level
// granularity.
type Decoder struct {
	// buf holds the bytes this decoder reads; for a decoder produced by
	// SubDecoder it is a window into the parent's buffer.
	buf []byte

	// prevOff is the absolute byte offset of buf[0] in the outermost buffer,
	// off is the index of the current byte within buf, and bit is how many
	// bits of buf[off] have already been consumed (0 means byte-aligned).
	prevOff, off, bit int

	// arg is caller-supplied context; SubDecoder copies it to the child.
	arg any
}
// NewDecoder returns a Decoder positioned at the first bit of buf.
//
// prevOff is the absolute byte offset at which buf starts within some larger
// stream; it is added when computing absolute positions (Pos, Seek,
// SubDecoder). Pass 0 for a standalone buffer.
func NewDecoder(buf []byte, prevOff int) *Decoder {
	dec := new(Decoder)
	dec.buf = buf
	dec.prevOff = prevOff
	return dec
}
// NewReaderDecoder reads r to completion and returns a Decoder over the bytes
// obtained, starting at absolute offset 0.
//
// The signature has no error result, so a read failure is not reported: the
// bytes io.ReadAll collected before the failure are still decoded, and a
// truncated buffer only shows up later as EOF from the read methods.
func NewReaderDecoder(r io.Reader) *Decoder {
	var data []byte
	data, _ = io.ReadAll(r) // best-effort: keep partial data on error
	return NewDecoder(data, 0)
}
// Pos reports the current read position as an absolute bit offset, measured
// from the start of the outermost buffer (prevOff is included).
func (dec *Decoder) Pos() int {
	absByte := dec.prevOff + dec.off
	return absByte*8 + dec.bit
}
// Seek moves the read position to the absolute bit offset pos, in the same
// units Pos reports. pos must lie between the bit offset of buf[0] and the
// bit just past the last byte of buf, inclusive; otherwise an error is
// returned and the position is left unchanged.
func (dec *Decoder) Seek(pos int) error {
	lo := dec.prevOff * 8
	hi := (dec.prevOff + len(dec.buf)) * 8
	if pos >= lo && pos <= hi {
		dec.off, dec.bit = pos/8-dec.prevOff, pos%8
		return nil
	}
	return fmtErr("pos(%d.%d) illegal", pos/8, pos%8)
}
// SubDecoder returns a child Decoder over the next n bytes, starting at the
// first byte boundary at or after the current position. If n <= 0, or fewer
// than n bytes remain, the child covers everything to the end of buf.
//
// The child shares dec's backing array (no copy), keeps absolute positions
// consistent with dec via prevOff, and inherits dec's arg. dec's own position
// is not advanced.
func (dec *Decoder) SubDecoder(n int) *Decoder {
	start := dec.off
	if dec.bit > 0 {
		start++ // skip the partially consumed byte
	}
	rest := len(dec.buf) - start
	if n <= 0 || n > rest {
		n = rest
	}
	child := &Decoder{arg: dec.arg}
	child.buf = dec.buf[start : start+n]
	child.prevOff = dec.prevOff + start
	return child
}
// SetArg attaches an arbitrary caller-supplied value to the decoder. It can be
// read back with Arg and is copied into every decoder later created with
// SubDecoder.
func (dec *Decoder) SetArg(a any) { dec.arg = a }
// Arg returns the value most recently stored with SetArg (or inherited from
// the parent decoder), or nil if none was set.
func (dec *Decoder) Arg() any { return dec.arg }
// Read copies up to len(p) of the remaining bytes into p, giving Decoder the
// io.Reader method shape. A partially consumed byte is skipped first, so
// reading always resumes on a byte boundary. An empty p returns 0, nil; once
// the buffer is exhausted it returns 0 and the package's EOF value.
// NOTE(review): EOF comes from newErr, not io.EOF, so io helpers such as
// io.ReadAll may treat the end of input as a failure — confirm.
func (dec *Decoder) Read(p []byte) (n int, err error) {
	if dec.bit > 0 {
		dec.off, dec.bit = dec.off+1, 0
	}
	if len(p) == 0 {
		return 0, nil
	}
	if avail := len(dec.buf) - dec.off; avail <= 0 {
		return 0, EOF
	}
	n = copy(p, dec.buf[dec.off:])
	dec.off += n
	return n, nil
}
// ReadByte returns the next whole byte, matching the io.ByteReader method
// shape. A partially consumed byte is skipped first. When no byte remains it
// returns 0 and the package's EOF value.
func (dec *Decoder) ReadByte() (byte, error) {
	if dec.bit > 0 {
		dec.bit = 0
		dec.off++
	}
	if dec.off < len(dec.buf) {
		c := dec.buf[dec.off]
		dec.off++
		return c, nil
	}
	return 0, EOF
}
// ReadBytes skips any partially consumed byte and then returns the next n
// bytes. If n <= 0, or fewer than n bytes remain, it returns everything left.
//
// The result aliases the decoder's buffer (no copy); copy it before keeping or
// modifying it. When no bytes remain it panics with EOF instead of returning
// an error. Decode converts decErr panics back into errors, so this is meant
// to be called while decoding (e.g. from UnmarshalBinary).
// NOTE(review): confirm newErr returns a decErr; otherwise the panic escapes
// Decode.
func (dec *Decoder) ReadBytes(n int) []byte {
	if dec.bit > 0 {
		dec.off++
		dec.bit = 0
	}
	left := len(dec.buf) - dec.off
	if left <= 0 {
		panic(EOF)
	}
	if n <= 0 || n > left {
		n = left
	}
	out := dec.buf[dec.off : dec.off+n]
	dec.off += n
	return out
}
// IsEof reports whether every bit of the buffer has been consumed: the read
// position is byte-aligned and sits exactly at the end of buf.
func (dec *Decoder) IsEof() bool {
	return dec.bit <= 0 && dec.off == len(dec.buf)
}
// Unmarshal decodes each value in e in order, with no explicit bit width.
// It stops at the first failure and returns that error; values decoded
// before the failure keep their new contents.
func (dec *Decoder) Unmarshal(isBig bool, e ...any) error {
	for i := range e {
		if err := dec.Decode(e[i], isBig, 0); err != nil {
			return err
		}
	}
	return nil
}
// Decode decodes one value at the decoder's current position.
//
// a is normally a pointer to the destination, since only settable values are
// written. A reflect.Value is also accepted and decoded directly. isBig
// selects big-endian byte order and bit is the requested bit width (0 for
// none). A nil a is a no-op.
//
// Decoding failures are raised internally as decErr panics; they are
// recovered here and returned as err. Any other panic is re-raised unchanged.
func (dec *Decoder) Decode(a any, isBig bool, bit int) (err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		e, ok := r.(decErr)
		if !ok {
			panic(r)
		}
		err = e
	}()
	if a == nil {
		return nil
	}
	// A reflect.Value is decoded as-is. Wrapping it again with reflect.ValueOf
	// would also decode the reflect.Value struct itself, so exactly one of the
	// two forms must be used.
	v, ok := a.(reflect.Value)
	if !ok {
		v = reflect.ValueOf(a)
	}
	dec.decode(v, isBig, bit)
	return nil
}
func (dec *Decoder) decode(v reflect.Value, isBig bool, bit int) {
if !v.IsValid() {
return
}
if v.CanInterface() && dec.handleMethods(v, isBig, bit) {
return
}
if v.Kind() != reflect.Pointer && v.CanAddr() && dec.handleMethods(v.Addr(), isBig, bit) {
return
}
switch v.Kind() {
case reflect.Bool:
if v.CanSet() {
v.SetBool(dec.decodeInteger(8, isBig, bit) != 0)
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if v.CanSet() {
v.SetInt(int64(dec.decodeInteger(kindSize(v.Kind()), isBig, bit)))
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
if v.CanSet() {
v.SetUint(dec.decodeInteger(kindSize(v.Kind()), isBig, bit))
}
case reflect.Float32, reflect.Float64:
if v.CanSet() {
v.SetFloat(dec.decodeFloat(kindSize(v.Kind()), isBig))
}
case reflect.Complex64, reflect.Complex128:
if v.CanSet() {
v.SetComplex(dec.decodeComplex(kindSize(v.Kind()), isBig))
}
case reflect.Slice:
if v.IsNil() {
dec.decodeSlice(v, isBig, bit)
break
}
fallthrough
case reflect.Array:
l := v.Len()
for i := 0; i < l; i++ {
dec.decode(v.Index(i), isBig, bit)
}
case reflect.Interface, reflect.Pointer:
if v.IsNil() {
break
}
dec.decode(v.Elem(), isBig, bit)
case reflect.Map:
dec.decodeMap