gin中validator模块的源码分析
# 简介
在gin中使用的是validator模块来对表单进行校验的。
validator模块github地址 (opens new window)
# 懒加载validate对象
众所周知,在api层需要使用gin.Context中的ShouldBindJSON方法来对request中的json字段进行校验,例子如下:
func login(c *gin.Context) {
var user User
if err := c.ShouldBindJSON(&user); err != nil {
errs, ok := err.(validator.ValidationErrors)
if !ok {
// 非校验错误,其他错误直接返回
c.JSON(http.StatusOK, gin.H{"msg": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"msg": errs.Translate(global.Trans)})
return
}
2
3
4
5
6
7
8
9
10
11
12
13
14
15
在c.ShourdBindJSON中,一直往下跳转,跟到路径binding/default_validator.go
下,可以看到代码如下,对表单的校验是调用v.validate.Struct(obj)
来完成的,这里的validate是validator库中Validate对象,所以在之前使用v.lazyinit()
来创建该对象。
// validateStruct receives struct type
func (v *defaultValidator) validateStruct(obj any) error {
v.lazyinit()
return v.validate.Struct(obj)
}
2
3
4
5
再看v.lazyinit()
方法,这里使用了once.Do方法,该once是sync.Onec对象,Do中的方法只会执行一次,也就是只会创建一次Validate对象,保持该对象的单例。
type defaultValidator struct {
once sync.Once
validate *validator.Validate
}
func (v *defaultValidator) lazyinit() {
v.once.Do(func() {
v.validate = validator.New()
v.validate.SetTagName("binding")
})
}
2
3
4
5
6
7
8
9
10
11
那么有这么一种场景,如果我想要对validator中的validate对象进行配置,该怎么办呢?因为上面都是在使用时懒加载才加载的,我们需要提前拿到validate对象并进行配置,该如何处理?
在代码binding/default_validator.go
中,提供了Engine方法,我们可以使用该方法给使用这调用提前加载validate对象,并返回该对象,就可以进行配置了。
func (v *defaultValidator) Engine() any {
v.lazyinit()
return v.validate
}
2
3
4
结论:
结构体中的对象可以使用懒加载的方式,在使用的时候再进行创建。
可以使用once.Do的方法创建对象保持单例。
可以提供方法提前加载对象,然后返回出来进行使用配置
# 多种请求参数的校验
我们使用c.ShouldBindJSON(&user)
可以对json格式的请求参数进行校验,也可以使用c.ShouldBindXML
对xml格式的请求参数进行校验。
在代码binding/binding.go中可以看到定义了多种对request进行校验的全局对象。
var (
JSON = jsonBinding{}
XML = xmlBinding{}
Form = formBinding{}
Query = queryBinding{}
FormPost = formPostBinding{}
FormMultipart = formMultipartBinding{}
ProtoBuf = protobufBinding{}
MsgPack = msgpackBinding{}
YAML = yamlBinding{}
Uri = uriBinding{}
Header = headerBinding{}
TOML = tomlBinding{}
)
2
3
4
5
6
7
8
9
10
11
12
13
14
15
进入c.ShouldBindJSON
方法,可以看到调用了c.ShouldBindWith
方法并将表单对象和全局的JSON对象传入其中。
// ShouldBindJSON is a shortcut for c.ShouldBindWith(obj, binding.JSON).
func (c *Context) ShouldBindJSON(obj any) error {
return c.ShouldBindWith(obj, binding.JSON)
}
2
3
4
在c.ShouldBindWith
中,是直接使用上面的JSON全局对象的bind方法,将requets和表单对象传入其中。
// ShouldBindWith binds the passed struct pointer using the specified binding engine.
// See the binding package.
func (c *Context) ShouldBindWith(obj any, b binding.Binding) error {
return b.Bind(c.Request, obj)
}
2
3
4
5
再看b.Bind方法,你会发现,它做的事情是:校验请求参数是否为json格式,然后再调用validate方法,该方法就是去创建go-palyground模块中Validate对象,然后调用其struct方法进行参数验证。
func (jsonBinding) Bind(req *http.Request, obj any) error {
if req == nil || req.Body == nil {
return errors.New("invalid request")
}
return decodeJSON(req.Body, obj)
}
func decodeJSON(r io.Reader, obj any) error {
decoder := json.NewDecoder(r)
if EnableDecoderUseNumber {
decoder.UseNumber()
}
if EnableDecoderDisallowUnknownFields {
decoder.DisallowUnknownFields()
}
if err := decoder.Decode(obj); err != nil {
return err
}
return validate(obj)
}
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
你再看看ShouldBindXML方法,就会发现是和上面一样的过程,调用的是全局对象XML中的bind方法,然后再校验请求参数是否为xml格式,最终调用validate方法。
总结:
在binding/binding.go
中的全局对象都实现了一个bind方法,该方法是对请求参数的格式进行检验,然后最终会调用validate方法,去创建go-palyground模块中Validate对象,再调用其struct对象进行请求参数的内容校验。
创建多个ShouldBindXXX,在不同的方法中使用不同的全局对象,这些对象都有一个相同的方法bind,然后再统一调用bind方法进行校验。
# 钩子方法
validator库中Validate结构体提供了一系列的钩子方法,在校验中的过程中,提供给使用者来修改其中的部分内容。
我们看一下validator_instance.go
文件中Validate结构体,该结构体中包含了hasTagNameFunc
和tagNameFunc
两个变量,tagNameFunc
代表着钩子方法,使用者可以自定方法来赋值给它,hasTagNameFunc
是bool类型代表是否配置了钩子方法.
type Validate struct {
tagName string
pool *sync.Pool
hasCustomFuncs bool
hasTagNameFunc bool
tagNameFunc TagNameFunc
structLevelFuncs map[reflect.Type]StructLevelFuncCtx
customFuncs map[reflect.Type]CustomTypeFunc
aliases map[string]string
validations map[string]internalValidationFuncWrapper
transTagFunc map[ut.Translator]map[string]TranslationFunc // map[<locale>]map[<tag>]TranslationFunc
tagCache *tagCache
structCache *structCache
}
2
3
4
5
6
7
8
9
10
11
12
13
14
使用者可以调用RegisterTagNameFunc
方法来注册自己写的方法
func (v *Validate) RegisterTagNameFunc(fn TagNameFunc) {
v.tagNameFunc = fn
v.hasTagNameFunc = true
}
2
3
4
在代码cache.go
中extractStructCache方法中有下面一段代码,通过调用hasTagNameFunc方法来判断是否设置了钩子方法,如果设置了,则调用tagNameFunc来执行使用者自定的方法。
if v.hasTagNameFunc {
name := v.tagNameFunc(fld)
if len(name) > 0 {
customName = name
}
}
2
3
4
5
6
总结:
- 可以提供钩子方法暴露给使用者参与到执行过程中来。
# 对象池的应用
看文件validator_instance.go
中的Struct方法,这个方法就是表单的校验的入口方法,可以看到它又调用了StructCtx方法。
func (v *Validate) Struct(s interface{}) error {
return v.StructCtx(context.Background(), s)
}
2
3
再来看看StructCtx方法,其中pool是*sync.Pool类型,调用Get方法获取validate对象,然后再调用其validateStruct方法,其中会进行请求参数的校验,然后如果校验有误,会将错误信息保存在errs中。valudate使用完成后,再put回pool中。
因为每一次校验都需要创建validate对象,所以这里使用了sync.Pool可以复用临时对象,减少内存的分配,降低GC压力。
func (v *Validate) StructCtx(ctx context.Context, s interface{}) (err error) {
val := reflect.ValueOf(s)
top := val
if val.Kind() == reflect.Ptr && !val.IsNil() {
val = val.Elem()
}
if val.Kind() != reflect.Struct || val.Type() == timeType {
return &InvalidValidationError{Type: reflect.TypeOf(s)}
}
// good to validate
vd := v.pool.Get().(*validate)
vd.top = top
vd.isPartial = false
// vd.hasExcludes = false // only need to reset in StructPartial and StructExcept
vd.validateStruct(ctx, top, val, val.Type(), vd.ns[0:0], vd.actualNs[0:0], nil)
if len(vd.errs) > 0 {
err = vd.errs
vd.errs = nil
}
v.pool.Put(vd)
return
}
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
总结:
- 当需要频繁创建相同对象,并且该对象是无状态时,可以使用sync.Pool来提高复用临时对象,减少内存的分配,降低GC压力。
# 根据标签校验过程
在懒加载创建validator的Validate对象时,是调用validator_instance.go
中New方法来创建该对象,在该方法中初始化所有标签及标签对应的校验方法并保存在Validate对象中的validations变量中
for k, val := range bakedInValidators {
switch k {
// these require that even if the value is nil that the validation should run, omitempty still overrides this behaviour
case requiredIfTag, requiredUnlessTag, requiredWithTag, requiredWithAllTag, requiredWithoutTag, requiredWithoutAllTag,
excludedWithTag, excludedWithAllTag, excludedWithoutTag, excludedWithoutAllTag:
_ = v.registerValidation(k, wrapFunc(val), true, true)
default:
// no need to error check here, baked in will always be valid
_ = v.registerValidation(k, wrapFunc(val), true, false)
}
}
func (v *Validate) registerValidation(tag string, fn FuncCtx, bakedIn bool, nilCheckable bool) error {
if len(tag) == 0 {
return errors.New("function Key cannot be empty")
}
if fn == nil {
return errors.New("function cannot be empty")
}
_, ok := restrictedTags[tag]
if !bakedIn && (ok || strings.ContainsAny(tag, restrictedTagChars)) {
panic(fmt.Sprintf(restrictedTagErr, tag))
}
v.validations[tag] = internalValidationFuncWrapper{fn: fn, runValidatinOnNil: nilCheckable}
return nil
}
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
我们可以在baked_in.go
文件中看到有定义以标签为key,校验方法为value的map对象。
bakedInValidators = map[string]Func{
"required": hasValue,
"required_if": requiredIf,
"required_unless": requiredUnless,
"required_with": requiredWith,
"required_with_all": requiredWithAll,
"required_without": requiredWithout,
"required_without_all": requiredWithoutAll,
......
2
3
4
5
6
7
8
9
在代码cache.go
中extractStructCache方法中,会遍历请求参数的每一个字段,然后根据该字段的tag创建对应的ctag对象,再创建该字段的cField对象,并将ctag传入。
if len(tag) > 0 {
ctag, _ = v.parseFieldTagsRecursive(tag, fld.Name, "", false)
} else {
// even if field doesn't have validations need cTag for traversing to potential inner/nested
// elements of the field.
ctag = new(cTag)
}
cs.fields = append(cs.fields, &cField{
idx: i,
name: fld.Name,
altName: customName,
cTags: ctag,
namesEqual: fld.Name == customName,
})
2
3
4
5
6
7
8
9
10
11
12
13
14
15
在上面的v.parseFieldTagsRecursive
方法中会将从之前缓存的validatetions集合中找到该ctag的校验方法,然后赋值ctag.fn方法。
if wrapper, ok := v.validations[current.tag]; ok {
current.fn = wrapper.fn
current.runValidationWhenNil = wrapper.runValidatinOnNil
} else {
panic(strings.TrimSpace(fmt.Sprintf(undefinedValidation, current.tag, fieldName)))
}
2
3
4
5
6
在文件validator.go中traverseField方法校验字段中会调用上面传递的fn方法,返回值为false时,会创建fieldError对象保存在validate的errs中
if !ct.fn(ctx, v) {
v.str1 = string(append(ns, cf.altName...))
if v.v.hasTagNameFunc {
v.str2 = string(append(structNs, cf.name...))
} else {
v.str2 = v.str1
}
v.errs = append(v.errs,
&fieldError{
v: v.v,
tag: ct.aliasTag,
actualTag: ct.tag,
ns: v.str1,
structNs: v.str2,
fieldLen: uint8(len(cf.altName)),
structfieldLen: uint8(len(cf.name)),
value: current.Interface(),
param: ct.param,
kind: kind,
typ: typ,
},
)
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
最终将这个error在validator_instance.go中的StructCtx方法中返回出去。
vd.validateStruct(ctx, top, val, val.Type(), vd.ns[0:0], vd.actualNs[0:0], nil)
if len(vd.errs) > 0 {
err = vd.errs
vd.errs = nil
}
2
3
4
5
6
结论:
error并不是调用validateStruct返回出来的,而是通过在结构体中创建了一个error,在过程中如果发生了错误则将该错误信息保存在结构体的error变量中,也是一种error的返回方式。
先初始化所有的校验方法,然后再根据每个字段选择校验方法进行执行。逻辑非常清晰。
# 错误提示信息翻译
从上面的过程可以发现,仅能够发现哪个字段再哪个tag中发生了错误,如下所示:
{'msg': "Key: 'User.password' Error:Field validation for 'password' failed on the 'required' tag"}
但是我们需要的时候人性化的错误提示。
在代码translations/zh/zh.go
的RegisterDefaultTranslations方法中可以看到定义了每个tag及对应中文提示信息。这些信息会注册到validator的Validate.transTagFunc变量中。
func RegisterDefaultTranslations(v *validator.Validate, trans ut.Translator) (err error) {
translations := []struct {
tag string
translation string
override bool
customRegisFunc validator.RegisterTranslationsFunc
customTransFunc validator.TranslationFunc
}{
{
tag: "required",
translation: "{0}为必填字段",
override: false,
},
2
3
4
5
6
7
8
9
10
11
12
13
14
每个字段的错误对象都有Translate方法,当调用方法时,会遍历当前所有错误的字段,然后再找到每个Tag对应的提示信息输出。
func (ve ValidationErrors) Translate(ut ut.Translator) ValidationErrorsTranslations {
trans := make(ValidationErrorsTranslations)
var fe *fieldError
for i := 0; i < len(ve); i++ {
fe = ve[i].(*fieldError)
// // in case an Anonymous struct was used, ensure that the key
// // would be 'Username' instead of ".Username"
// if len(fe.ns) > 0 && fe.ns[:1] == "." {
// trans[fe.ns[1:]] = fe.Translate(ut)
// continue
// }
trans[fe.ns] = fe.Translate(ut)
}
return trans
}
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
这个是每一个字段通过tag获取提示信息。
func (fe *fieldError) Translate(ut ut.Translator) string {
m, ok := fe.v.transTagFunc[ut]
if !ok {
return fe.Error()
}
fn, ok := m[fe.tag]
if !ok {
return fe.Error()
}
return fn(ut, fe)
}
2
3
4
5
6
7
8
9
10
11
12
13
14