Make 'where' template function accepts dot chaining key argument

'where' template function used to accept only each element's struct
field name, method name and map key name as its second argument. This
extends it to accept dot chaining key like 'Params.foo.bar' as the
argument. It evaluates sub elements of each array elements and checks it
matches the third argument value.

Typical use case would be for filtering Pages by user defined front
matter value. For example, to filter pages which have 'Params.foo.bar'
and its value is 'baz', it is used like

    {{ range where .Data.Pages "Params.foo.bar" "baz" }}
        {{ .Content }}
    {{ end }}

It ignores all leading and trailing dots so it can also be used with
".Params.foo.bar"
This commit is contained in:
Tatsushi Demachi
2014-12-29 11:33:12 +09:00
committed by bep
parent dd5bc0345b
commit fa8ac87d5e
3 changed files with 227 additions and 56 deletions

View File

@@ -289,6 +289,19 @@ func In(l interface{}, v interface{}) bool {
return false
}
// indirect is taken from 'text/template/exec.go'
func indirect(v reflect.Value) (rv reflect.Value, isNil bool) {
for ; v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface; v = v.Elem() {
if v.IsNil() {
return v, true
}
if v.Kind() == reflect.Interface && v.NumMethod() > 0 {
break
}
}
return v, false
}
// First is exposed to templates, to iterate over the first N items in a
// rangeable list.
func First(limit interface{}, seq interface{}) (interface{}, error) {
@@ -326,76 +339,122 @@ func First(limit interface{}, seq interface{}) (interface{}, error) {
return seqv.Slice(0, limitv).Interface(), nil
}
func Where(seq, key, match interface{}) (interface{}, error) {
var (
zero reflect.Value
errorType = reflect.TypeOf((*error)(nil)).Elem()
)
func evaluateSubElem(obj reflect.Value, elemName string) (reflect.Value, error) {
if !obj.IsValid() {
return zero, errors.New("can't evaluate an invalid value")
}
typ := obj.Type()
obj, isNil := indirect(obj)
// first, check whether obj has a method. In this case, obj is
// an interface, a struct or its pointer. If obj is a struct,
// to check all T and *T method, use obj pointer type Value
objPtr := obj
if objPtr.Kind() != reflect.Interface && objPtr.CanAddr() {
objPtr = objPtr.Addr()
}
mt, ok := objPtr.Type().MethodByName(elemName)
if ok {
if mt.PkgPath != "" {
return zero, fmt.Errorf("%s is an unexported method of type %s", elemName, typ)
}
// struct pointer has one receiver argument and interface doesn't have an argument
if mt.Type.NumIn() > 1 || mt.Type.NumOut() == 0 || mt.Type.NumOut() > 2 {
return zero, fmt.Errorf("%s is a method of type %s but doesn't satisfy requirements", elemName, typ)
}
if mt.Type.NumOut() == 1 && mt.Type.Out(0).Implements(errorType) {
return zero, fmt.Errorf("%s is a method of type %s but doesn't satisfy requirements", elemName, typ)
}
if mt.Type.NumOut() == 2 && !mt.Type.Out(1).Implements(errorType) {
return zero, fmt.Errorf("%s is a method of type %s but doesn't satisfy requirements", elemName, typ)
}
res := objPtr.Method(mt.Index).Call([]reflect.Value{})
if len(res) == 2 && !res[1].IsNil() {
return zero, fmt.Errorf("error at calling a method %s of type %s: %s", elemName, typ, res[1].Interface().(error))
}
return res[0], nil
}
// elemName isn't a method so next start to check whether it is
// a struct field or a map value. In both cases, it mustn't be
// a nil value
if isNil {
return zero, fmt.Errorf("can't evaluate a nil pointer of type %s by a struct field or map key name %s", typ, elemName)
}
switch obj.Kind() {
case reflect.Struct:
ft, ok := obj.Type().FieldByName(elemName)
if ok {
if ft.PkgPath != "" {
return zero, fmt.Errorf("%s is an unexported field of struct type %s", elemName, typ)
}
return obj.FieldByIndex(ft.Index), nil
}
return zero, fmt.Errorf("%s isn't a field of struct type %s", elemName, typ)
case reflect.Map:
kv := reflect.ValueOf(elemName)
if kv.Type().AssignableTo(obj.Type().Key()) {
return obj.MapIndex(kv), nil
}
return zero, fmt.Errorf("%s isn't a key of map type %s", elemName, typ)
}
return zero, fmt.Errorf("%s is neither a struct field, a method nor a map element of type %s", elemName, typ)
}
func Where(seq, key, match interface{}) (r interface{}, err error) {
seqv := reflect.ValueOf(seq)
kv := reflect.ValueOf(key)
mv := reflect.ValueOf(match)
// this is better than my first pass; ripped from text/template/exec.go indirect():
for ; seqv.Kind() == reflect.Ptr || seqv.Kind() == reflect.Interface; seqv = seqv.Elem() {
if seqv.IsNil() {
return nil, errors.New("can't iterate over a nil value")
}
if seqv.Kind() == reflect.Interface && seqv.NumMethod() > 0 {
break
}
seqv, isNil := indirect(seqv)
if isNil {
return nil, errors.New("can't iterate over a nil value of type " + reflect.ValueOf(seq).Type().String())
}
var path []string
if kv.Kind() == reflect.String {
path = strings.Split(strings.Trim(kv.String(), "."), ".")
}
switch seqv.Kind() {
case reflect.Array, reflect.Slice:
r := reflect.MakeSlice(seqv.Type(), 0, 0)
rv := reflect.MakeSlice(seqv.Type(), 0, 0)
for i := 0; i < seqv.Len(); i++ {
var vvv reflect.Value
vv := seqv.Index(i)
switch vv.Kind() {
case reflect.Map:
if kv.Type() == vv.Type().Key() && vv.MapIndex(kv).IsValid() {
rvv := seqv.Index(i)
if kv.Kind() == reflect.String {
vvv = rvv
for _, elemName := range path {
vvv, err = evaluateSubElem(vvv, elemName)
if err != nil {
return nil, err
}
}
} else {
vv, _ := indirect(rvv)
if vv.Kind() == reflect.Map && kv.Type().AssignableTo(vv.Type().Key()) {
vvv = vv.MapIndex(kv)
}
case reflect.Struct:
if kv.Kind() == reflect.String {
method := vv.MethodByName(kv.String())
if method.IsValid() && method.Type().NumIn() == 0 && method.Type().NumOut() > 0 {
vvv = method.Call(nil)[0]
} else if vv.FieldByName(kv.String()).IsValid() {
vvv = vv.FieldByName(kv.String())
}
}
case reflect.Ptr:
if !vv.IsNil() {
ev := vv.Elem()
switch ev.Kind() {
case reflect.Map:
if kv.Type() == ev.Type().Key() && ev.MapIndex(kv).IsValid() {
vvv = ev.MapIndex(kv)
}
case reflect.Struct:
if kv.Kind() == reflect.String {
method := vv.MethodByName(kv.String())
if method.IsValid() && method.Type().NumIn() == 0 && method.Type().NumOut() > 0 {
vvv = method.Call(nil)[0]
} else if ev.FieldByName(kv.String()).IsValid() {
vvv = ev.FieldByName(kv.String())
}
}
}
}
}
if vvv.IsValid() && mv.Type() == vvv.Type() {
switch mv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if mv.Int() == vvv.Int() {
r = reflect.Append(r, vv)
rv = reflect.Append(rv, rvv)
}
case reflect.String:
if mv.String() == vvv.String() {
r = reflect.Append(r, vv)
rv = reflect.Append(rv, rvv)
}
}
}
}
return r.Interface(), nil
return rv.Interface(), nil
default:
return nil, errors.New("can't iterate over " + reflect.ValueOf(seq).Type().String())
}