// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
|
|
|
|
package impl
|
|
|
|
import (
|
|
"fmt"
|
|
"math"
|
|
"math/bits"
|
|
"reflect"
|
|
"unicode/utf8"
|
|
|
|
"google.golang.org/protobuf/encoding/protowire"
|
|
"google.golang.org/protobuf/internal/encoding/messageset"
|
|
"google.golang.org/protobuf/internal/flags"
|
|
"google.golang.org/protobuf/internal/genid"
|
|
"google.golang.org/protobuf/internal/strs"
|
|
"google.golang.org/protobuf/reflect/protoreflect"
|
|
"google.golang.org/protobuf/reflect/protoregistry"
|
|
"google.golang.org/protobuf/runtime/protoiface"
|
|
)
|
|
|
|
// ValidationStatus is the result of validating the wire-format encoding of a message.
type ValidationStatus int

const (
	// ValidationUnknown indicates that unmarshaling the message might succeed or fail.
	// The validator was unable to render a judgement.
	//
	// The only causes of this status are an aberrant message type appearing somewhere
	// in the message or a failure in the extension resolver.
	ValidationUnknown ValidationStatus = iota + 1

	// ValidationInvalid indicates that unmarshaling the message will fail.
	ValidationInvalid

	// ValidationValid indicates that unmarshaling the message will succeed.
	ValidationValid
)

// String returns the name of the status constant. Any value that is not one
// of the declared constants (including the zero value) is rendered as
// "ValidationStatus(N)", where N is its integer value.
func (v ValidationStatus) String() string {
	switch v {
	case ValidationUnknown:
		return "ValidationUnknown"
	case ValidationInvalid:
		return "ValidationInvalid"
	case ValidationValid:
		return "ValidationValid"
	default:
		return fmt.Sprintf("ValidationStatus(%d)", int(v))
	}
}
|
|
|
|
// Validate determines whether the contents of the buffer are a valid wire encoding
|
|
// of the message type.
|
|
//
|
|
// This function is exposed for testing.
|
|
func Validate(mt protoreflect.MessageType, in protoiface.UnmarshalInput) (out protoiface.UnmarshalOutput, _ ValidationStatus) {
|
|
mi, ok := mt.(*MessageInfo)
|
|
if !ok {
|
|
return out, ValidationUnknown
|
|
}
|
|
if in.Resolver == nil {
|
|
in.Resolver = protoregistry.GlobalTypes
|
|
}
|
|
o, st := mi.validate(in.Buf, 0, unmarshalOptions{
|
|
flags: in.Flags,
|
|
resolver: in.Resolver,
|
|
})
|
|
if o.initialized {
|
|
out.Flags |= protoiface.UnmarshalInitialized
|
|
}
|
|
return out, st
|
|
}
|
|
|
|
type validationInfo struct {
|
|
mi *MessageInfo
|
|
typ validationType
|
|
keyType, valType validationType
|
|
|
|
// For non-required fields, requiredBit is 0.
|
|
//
|
|
// For required fields, requiredBit's nth bit is set, where n is a
|
|
// unique index in the range [0, MessageInfo.numRequiredFields).
|
|
//
|
|
// If there are more than 64 required fields, requiredBit is 0.
|
|
requiredBit uint64
|
|
}
|
|
|
|
// validationType classifies a field by how the validator checks its encoding.
type validationType uint8

const (
	// validationTypeOther is the zero value, used when a field has not been
	// assigned a more specific validation class.
	validationTypeOther validationType = iota

	// validationTypeMessage, validationTypeGroup, and validationTypeMap mark
	// fields whose contents are themselves validated as nested fields.
	validationTypeMessage
	validationTypeGroup
	validationTypeMap

	// The validationTypeRepeated* classes mark list fields of scalar kinds,
	// grouped by the wire type of their elements. A length-delimited
	// occurrence of such a field is checked as packed elements.
	validationTypeRepeatedVarint
	validationTypeRepeatedFixed32
	validationTypeRepeatedFixed64

	// The following classes mark singular scalar fields, grouped by wire type.
	validationTypeVarint
	validationTypeFixed32
	validationTypeFixed64
	validationTypeBytes

	// validationTypeUTF8String marks a string whose bytes must be valid UTF-8.
	validationTypeUTF8String

	// validationTypeMessageSetItem marks the item field of a MessageSet.
	validationTypeMessageSetItem
)
|
|
|
|
func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
|
|
var vi validationInfo
|
|
switch {
|
|
case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
|
|
switch fd.Kind() {
|
|
case protoreflect.MessageKind:
|
|
vi.typ = validationTypeMessage
|
|
if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
|
|
vi.mi = getMessageInfo(ot.Field(0).Type)
|
|
}
|
|
case protoreflect.GroupKind:
|
|
vi.typ = validationTypeGroup
|
|
if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
|
|
vi.mi = getMessageInfo(ot.Field(0).Type)
|
|
}
|
|
case protoreflect.StringKind:
|
|
if strs.EnforceUTF8(fd) {
|
|
vi.typ = validationTypeUTF8String
|
|
}
|
|
}
|
|
default:
|
|
vi = newValidationInfo(fd, ft)
|
|
}
|
|
if fd.Cardinality() == protoreflect.Required {
|
|
// Avoid overflow. The required field check is done with a 64-bit mask, with
|
|
// any message containing more than 64 required fields always reported as
|
|
// potentially uninitialized, so it is not important to get a precise count
|
|
// of the required fields past 64.
|
|
if mi.numRequiredFields < math.MaxUint8 {
|
|
mi.numRequiredFields++
|
|
vi.requiredBit = 1 << (mi.numRequiredFields - 1)
|
|
}
|
|
}
|
|
return vi
|
|
}
|
|
|
|
func newValidationInfo(fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
|
|
var vi validationInfo
|
|
switch {
|
|
case fd.IsList():
|
|
switch fd.Kind() {
|
|
case protoreflect.MessageKind:
|
|
vi.typ = validationTypeMessage
|
|
if ft.Kind() == reflect.Slice {
|
|
vi.mi = getMessageInfo(ft.Elem())
|
|
}
|
|
case protoreflect.GroupKind:
|
|
vi.typ = validationTypeGroup
|
|
if ft.Kind() == reflect.Slice {
|
|
vi.mi = getMessageInfo(ft.Elem())
|
|
}
|
|
case protoreflect.StringKind:
|
|
vi.typ = validationTypeBytes
|
|
if strs.EnforceUTF8(fd) {
|
|
vi.typ = validationTypeUTF8String
|
|
}
|
|
default:
|
|
switch wireTypes[fd.Kind()] {
|
|
case protowire.VarintType:
|
|
vi.typ = validationTypeRepeatedVarint
|
|
case protowire.Fixed32Type:
|
|
vi.typ = validationTypeRepeatedFixed32
|
|
case protowire.Fixed64Type:
|
|
vi.typ = validationTypeRepeatedFixed64
|
|
}
|
|
}
|
|
case fd.IsMap():
|
|
vi.typ = validationTypeMap
|
|
switch fd.MapKey().Kind() {
|
|
case protoreflect.StringKind:
|
|
if strs.EnforceUTF8(fd) {
|
|
vi.keyType = validationTypeUTF8String
|
|
}
|
|
}
|
|
switch fd.MapValue().Kind() {
|
|
case protoreflect.MessageKind:
|
|
vi.valType = validationTypeMessage
|
|
if ft.Kind() == reflect.Map {
|
|
vi.mi = getMessageInfo(ft.Elem())
|
|
}
|
|
case protoreflect.StringKind:
|
|
if strs.EnforceUTF8(fd) {
|
|
vi.valType = validationTypeUTF8String
|
|
}
|
|
}
|
|
default:
|
|
switch fd.Kind() {
|
|
case protoreflect.MessageKind:
|
|
vi.typ = validationTypeMessage
|
|
if !fd.IsWeak() {
|
|
vi.mi = getMessageInfo(ft)
|
|
}
|
|
case protoreflect.GroupKind:
|
|
vi.typ = validationTypeGroup
|
|
vi.mi = getMessageInfo(ft)
|
|
case protoreflect.StringKind:
|
|
vi.typ = validationTypeBytes
|
|
if strs.EnforceUTF8(fd) {
|
|
vi.typ = validationTypeUTF8String
|
|
}
|
|
default:
|
|
switch wireTypes[fd.Kind()] {
|
|
case protowire.VarintType:
|
|
vi.typ = validationTypeVarint
|
|
case protowire.Fixed32Type:
|
|
vi.typ = validationTypeFixed32
|
|
case protowire.Fixed64Type:
|
|
vi.typ = validationTypeFixed64
|
|
case protowire.BytesType:
|
|
vi.typ = validationTypeBytes
|
|
}
|
|
}
|
|
}
|
|
return vi
|
|
}
|
|
|
|
func (mi *MessageInfo) validate(b []byte, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
|
|
mi.init()
|
|
type validationState struct {
|
|
typ validationType
|
|
keyType, valType validationType
|
|
endGroup protowire.Number
|
|
mi *MessageInfo
|
|
tail []byte
|
|
requiredMask uint64
|
|
}
|
|
|
|
// Pre-allocate some slots to avoid repeated slice reallocation.
|
|
states := make([]validationState, 0, 16)
|
|
states = append(states, validationState{
|
|
typ: validationTypeMessage,
|
|
mi: mi,
|
|
})
|
|
if groupTag > 0 {
|
|
states[0].typ = validationTypeGroup
|
|
states[0].endGroup = groupTag
|
|
}
|
|
initialized := true
|
|
start := len(b)
|
|
State:
|
|
for len(states) > 0 {
|
|
st := &states[len(states)-1]
|
|
for len(b) > 0 {
|
|
// Parse the tag (field number and wire type).
|
|
var tag uint64
|
|
if b[0] < 0x80 {
|
|
tag = uint64(b[0])
|
|
b = b[1:]
|
|
} else if len(b) >= 2 && b[1] < 128 {
|
|
tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
|
|
b = b[2:]
|
|
} else {
|
|
var n int
|
|
tag, n = protowire.ConsumeVarint(b)
|
|
if n < 0 {
|
|
return out, ValidationInvalid
|
|
}
|
|
b = b[n:]
|
|
}
|
|
var num protowire.Number
|
|
if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
|
|
return out, ValidationInvalid
|
|
} else {
|
|
num = protowire.Number(n)
|
|
}
|
|
wtyp := protowire.Type(tag & 7)
|
|
|
|
if wtyp == protowire.EndGroupType {
|
|
if st.endGroup == num {
|
|
goto PopState
|
|
}
|
|
return out, ValidationInvalid
|
|
}
|
|
var vi validationInfo
|
|
switch {
|
|
case st.typ == validationTypeMap:
|
|
switch num {
|
|
case genid.MapEntry_Key_field_number:
|
|
vi.typ = st.keyType
|
|
case genid.MapEntry_Value_field_number:
|
|
vi.typ = st.valType
|
|
vi.mi = st.mi
|
|
vi.requiredBit = 1
|
|
}
|
|
case flags.ProtoLegacy && st.mi.isMessageSet:
|
|
switch num {
|
|
case messageset.FieldItem:
|
|
vi.typ = validationTypeMessageSetItem
|
|
}
|
|
default:
|
|
var f *coderFieldInfo
|
|
if int(num) < len(st.mi.denseCoderFields) {
|
|
f = st.mi.denseCoderFields[num]
|
|
} else {
|
|
f = st.mi.coderFields[num]
|
|
}
|
|
if f != nil {
|
|
vi = f.validation
|
|
if vi.typ == validationTypeMessage && vi.mi == nil {
|
|
// Probable weak field.
|
|
//
|
|
// TODO: Consider storing the results of this lookup somewhere
|
|
// rather than recomputing it on every validation.
|
|
fd := st.mi.Desc.Fields().ByNumber(num)
|
|
if fd == nil || !fd.IsWeak() {
|
|
break
|
|
}
|
|
messageName := fd.Message().FullName()
|
|
messageType, err := protoregistry.GlobalTypes.FindMessageByName(messageName)
|
|
switch err {
|
|
case nil:
|
|
vi.mi, _ = messageType.(*MessageInfo)
|
|
case protoregistry.NotFound:
|
|
vi.typ = validationTypeBytes
|
|
default:
|
|
return out, ValidationUnknown
|
|
}
|
|
}
|
|
break
|
|
}
|
|
// Possible extension field.
|
|
//
|
|
// TODO: We should return ValidationUnknown when:
|
|
// 1. The resolver is not frozen. (More extensions may be added to it.)
|
|
// 2. The resolver returns preg.NotFound.
|
|
// In this case, a type added to the resolver in the future could cause
|
|
// unmarshaling to begin failing. Supporting this requires some way to
|
|
// determine if the resolver is frozen.
|
|
xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
|
|
if err != nil && err != protoregistry.NotFound {
|
|
return out, ValidationUnknown
|
|
}
|
|
if err == nil {
|
|
vi = getExtensionFieldInfo(xt).validation
|
|
}
|
|
}
|
|
if vi.requiredBit != 0 {
|
|
// Check that the field has a compatible wire type.
|
|
// We only need to consider non-repeated field types,
|
|
// since repeated fields (and maps) can never be required.
|
|
ok := false
|
|
switch vi.typ {
|
|
case validationTypeVarint:
|
|
ok = wtyp == protowire.VarintType
|
|
case validationTypeFixed32:
|
|
ok = wtyp == protowire.Fixed32Type
|
|
case validationTypeFixed64:
|
|
ok = wtyp == protowire.Fixed64Type
|
|
case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
|
|
ok = wtyp == protowire.BytesType
|
|
case validationTypeGroup:
|
|
ok = wtyp == protowire.StartGroupType
|
|
}
|
|
if ok {
|
|
st.requiredMask |= vi.requiredBit
|
|
}
|
|
}
|
|
|
|
switch wtyp {
|
|
case protowire.VarintType:
|
|
if len(b) >= 10 {
|
|
switch {
|
|
case b[0] < 0x80:
|
|
b = b[1:]
|
|
case b[1] < 0x80:
|
|
b = b[2:]
|
|
case b[2] < 0x80:
|
|
b = b[3:]
|
|
case b[3] < 0x80:
|
|
b = b[4:]
|
|
case b[4] < 0x80:
|
|
b = b[5:]
|
|
case b[5] < 0x80:
|
|
b = b[6:]
|
|
case b[6] < 0x80:
|
|
b = b[7:]
|
|
case b[7] < 0x80:
|
|
b = b[8:]
|
|
case b[8] < 0x80:
|
|
b = b[9:]
|
|
case b[9] < 0x80 && b[9] < 2:
|
|
b = b[10:]
|
|
default:
|
|
return out, ValidationInvalid
|
|
}
|
|
} else {
|
|
switch {
|
|
case len(b) > 0 && b[0] < 0x80:
|
|
b = b[1:]
|
|
case len(b) > 1 && b[1] < 0x80:
|
|
b = b[2:]
|
|
case len(b) > 2 && b[2] < 0x80:
|
|
b = b[3:]
|
|
case len(b) > 3 && b[3] < 0x80:
|
|
b = b[4:]
|
|
case len(b) > 4 && b[4] < 0x80:
|
|
b = b[5:]
|
|
case len(b) > 5 && b[5] < 0x80:
|
|
b = b[6:]
|
|
case len(b) > 6 && b[6] < 0x80:
|
|
b = b[7:]
|
|
case len(b) > 7 && b[7] < 0x80:
|
|
b = b[8:]
|
|
case len(b) > 8 && b[8] < 0x80:
|
|
b = b[9:]
|
|
case len(b) > 9 && b[9] < 2:
|
|
b = b[10:]
|
|
default:
|
|
return out, ValidationInvalid
|
|
}
|
|
}
|
|
continue State
|
|
case protowire.BytesType:
|
|
var size uint64
|
|
if len(b) >= 1 && b[0] < 0x80 {
|
|
size = uint64(b[0])
|
|
b = b[1:]
|
|
} else if len(b) >= 2 && b[1] < 128 {
|
|
size = uint64(b[0]&0x7f) + uint64(b[1])<<7
|
|
b = b[2:]
|
|
} else {
|
|
var n int
|
|
size, n = protowire.ConsumeVarint(b)
|
|
if n < 0 {
|
|
return out, ValidationInvalid
|
|
}
|
|
b = b[n:]
|
|
}
|
|
if size > uint64(len(b)) {
|
|
return out, ValidationInvalid
|
|
}
|
|
v := b[:size]
|
|
b = b[size:]
|
|
switch vi.typ {
|
|
case validationTypeMessage:
|
|
if vi.mi == nil {
|
|
return out, ValidationUnknown
|
|
}
|
|
vi.mi.init()
|
|
fallthrough
|
|
case validationTypeMap:
|
|
if vi.mi != nil {
|
|
vi.mi.init()
|
|
}
|
|
states = append(states, validationState{
|
|
typ: vi.typ,
|
|
keyType: vi.keyType,
|
|
valType: vi.valType,
|
|
mi: vi.mi,
|
|
tail: b,
|
|
})
|
|
b = v
|
|
continue State
|
|
case validationTypeRepeatedVarint:
|
|
// Packed field.
|
|
for len(v) > 0 {
|
|
_, n := protowire.ConsumeVarint(v)
|
|
if n < 0 {
|
|
return out, ValidationInvalid
|
|
}
|
|
v = v[n:]
|
|
}
|
|
case validationTypeRepeatedFixed32:
|
|
// Packed field.
|
|
if len(v)%4 != 0 {
|
|
return out, ValidationInvalid
|
|
}
|
|
case validationTypeRepeatedFixed64:
|
|
// Packed field.
|
|
if len(v)%8 != 0 {
|
|
return out, ValidationInvalid
|
|
}
|
|
case validationTypeUTF8String:
|
|
if !utf8.Valid(v) {
|
|
return out, ValidationInvalid
|
|
}
|
|
}
|
|
case protowire.Fixed32Type:
|
|
if len(b) < 4 {
|
|
return out, ValidationInvalid
|
|
}
|
|
b = b[4:]
|
|
case protowire.Fixed64Type:
|
|
if len(b) < 8 {
|
|
return out, ValidationInvalid
|
|
}
|
|
b = b[8:]
|
|
case protowire.StartGroupType:
|
|
switch {
|
|
case vi.typ == validationTypeGroup:
|
|
if vi.mi == nil {
|
|
return out, ValidationUnknown
|
|
}
|
|
vi.mi.init()
|
|
states = append(states, validationState{
|
|
typ: validationTypeGroup,
|
|
mi: vi.mi,
|
|
endGroup: num,
|
|
})
|
|
continue State
|
|
case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
|
|
typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
|
|
if err != nil {
|
|
return out, ValidationInvalid
|
|
}
|
|
xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
|
|
switch {
|
|
case err == protoregistry.NotFound:
|
|
b = b[n:]
|
|
case err != nil:
|
|
return out, ValidationUnknown
|
|
default:
|
|
xvi := getExtensionFieldInfo(xt).validation
|
|
if xvi.mi != nil {
|
|
xvi.mi.init()
|
|
}
|
|
states = append(states, validationState{
|
|
typ: xvi.typ,
|
|
mi: xvi.mi,
|
|
tail: b[n:],
|
|
})
|
|
b = v
|
|
continue State
|
|
}
|
|
default:
|
|
n := protowire.ConsumeFieldValue(num, wtyp, b)
|
|
if n < 0 {
|
|
return out, ValidationInvalid
|
|
}
|
|
b = b[n:]
|
|
}
|
|
default:
|
|
return out, ValidationInvalid
|
|
}
|
|
}
|
|
if st.endGroup != 0 {
|
|
return out, ValidationInvalid
|
|
}
|
|
if len(b) != 0 {
|
|
return out, ValidationInvalid
|
|
}
|
|
b = st.tail
|
|
PopState:
|
|
numRequiredFields := 0
|
|
switch st.typ {
|
|
case validationTypeMessage, validationTypeGroup:
|
|
numRequiredFields = int(st.mi.numRequiredFields)
|
|
case validationTypeMap:
|
|
// If this is a map field with a message value that contains
|
|
// required fields, require that the value be present.
|
|
if st.mi != nil && st.mi.numRequiredFields > 0 {
|
|
numRequiredFields = 1
|
|
}
|
|
}
|
|
// If there are more than 64 required fields, this check will
|
|
// always fail and we will report that the message is potentially
|
|
// uninitialized.
|
|
if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
|
|
initialized = false
|
|
}
|
|
states = states[:len(states)-1]
|
|
}
|
|
out.n = start - len(b)
|
|
if initialized {
|
|
out.initialized = true
|
|
}
|
|
return out, ValidationValid
|
|
}
|