tpl/collections: Fix WordCount (etc.) regression in Where, Sort, Delimit

Fixes #11234
This commit is contained in:
Bjørn Erik Pedersen
2023-07-11 09:48:57 +02:00
parent f650e4d751
commit 5bec50838c
7 changed files with 86 additions and 27 deletions

View File

@@ -14,6 +14,7 @@
package collections
import (
"context"
"errors"
"fmt"
"reflect"
@@ -24,7 +25,7 @@ import (
)
// Where returns a filtered subset of collection c.
func (ns *Namespace) Where(c, key any, args ...any) (any, error) {
func (ns *Namespace) Where(ctx context.Context, c, key any, args ...any) (any, error) {
seqv, isNil := indirect(reflect.ValueOf(c))
if isNil {
return nil, errors.New("can't iterate over a nil value of type " + reflect.ValueOf(c).Type().String())
@@ -35,6 +36,8 @@ func (ns *Namespace) Where(c, key any, args ...any) (any, error) {
return nil, err
}
ctxv := reflect.ValueOf(ctx)
var path []string
kv := reflect.ValueOf(key)
if kv.Kind() == reflect.String {
@@ -43,9 +46,9 @@ func (ns *Namespace) Where(c, key any, args ...any) (any, error) {
switch seqv.Kind() {
case reflect.Array, reflect.Slice:
return ns.checkWhereArray(seqv, kv, mv, path, op)
return ns.checkWhereArray(ctxv, seqv, kv, mv, path, op)
case reflect.Map:
return ns.checkWhereMap(seqv, kv, mv, path, op)
return ns.checkWhereMap(ctxv, seqv, kv, mv, path, op)
default:
return nil, fmt.Errorf("can't iterate over %v", c)
}
@@ -275,7 +278,7 @@ func (ns *Namespace) checkCondition(v, mv reflect.Value, op string) (bool, error
return false, nil
}
func evaluateSubElem(obj reflect.Value, elemName string) (reflect.Value, error) {
func evaluateSubElem(ctx, obj reflect.Value, elemName string) (reflect.Value, error) {
if !obj.IsValid() {
return zero, errors.New("can't evaluate an invalid value")
}
@@ -301,12 +304,20 @@ func evaluateSubElem(obj reflect.Value, elemName string) (reflect.Value, error)
index := hreflect.GetMethodIndexByName(objPtr.Type(), elemName)
if index != -1 {
var args []reflect.Value
mt := objPtr.Type().Method(index)
num := mt.Type.NumIn()
maxNumIn := 1
if num > 1 && mt.Type.In(1).Implements(hreflect.ContextInterface) {
args = []reflect.Value{ctx}
maxNumIn = 2
}
switch {
case mt.PkgPath != "":
return zero, fmt.Errorf("%s is an unexported method of type %s", elemName, typ)
case mt.Type.NumIn() > 1:
return zero, fmt.Errorf("%s is a method of type %s but requires more than 1 parameter", elemName, typ)
case mt.Type.NumIn() > maxNumIn:
return zero, fmt.Errorf("%s is a method of type %s but requires more than %d parameter", elemName, typ, maxNumIn)
case mt.Type.NumOut() == 0:
return zero, fmt.Errorf("%s is a method of type %s but returns no output", elemName, typ)
case mt.Type.NumOut() > 2:
@@ -316,7 +327,7 @@ func evaluateSubElem(obj reflect.Value, elemName string) (reflect.Value, error)
case mt.Type.NumOut() == 2 && !mt.Type.Out(1).Implements(errorType):
return zero, fmt.Errorf("%s is a method of type %s returning two values but the second value is not an error type", elemName, typ)
}
res := objPtr.Method(mt.Index).Call([]reflect.Value{})
res := objPtr.Method(mt.Index).Call(args)
if len(res) == 2 && !res[1].IsNil() {
return zero, fmt.Errorf("error at calling a method %s of type %s: %s", elemName, typ, res[1].Interface().(error))
}
@@ -371,7 +382,7 @@ func parseWhereArgs(args ...any) (mv reflect.Value, op string, err error) {
// checkWhereArray handles the where-matching logic when the seqv value is an
// Array or Slice.
func (ns *Namespace) checkWhereArray(seqv, kv, mv reflect.Value, path []string, op string) (any, error) {
func (ns *Namespace) checkWhereArray(ctxv, seqv, kv, mv reflect.Value, path []string, op string) (any, error) {
rv := reflect.MakeSlice(seqv.Type(), 0, 0)
for i := 0; i < seqv.Len(); i++ {
@@ -385,7 +396,7 @@ func (ns *Namespace) checkWhereArray(seqv, kv, mv reflect.Value, path []string,
vvv = rvv
for i, elemName := range path {
var err error
vvv, err = evaluateSubElem(vvv, elemName)
vvv, err = evaluateSubElem(ctxv, vvv, elemName)
if err != nil {
continue
@@ -417,14 +428,14 @@ func (ns *Namespace) checkWhereArray(seqv, kv, mv reflect.Value, path []string,
}
// checkWhereMap handles the where-matching logic when the seqv value is a Map.
func (ns *Namespace) checkWhereMap(seqv, kv, mv reflect.Value, path []string, op string) (any, error) {
func (ns *Namespace) checkWhereMap(ctxv, seqv, kv, mv reflect.Value, path []string, op string) (any, error) {
rv := reflect.MakeMap(seqv.Type())
keys := seqv.MapKeys()
for _, k := range keys {
elemv := seqv.MapIndex(k)
switch elemv.Kind() {
case reflect.Array, reflect.Slice:
r, err := ns.checkWhereArray(elemv, kv, mv, path, op)
r, err := ns.checkWhereArray(ctxv, elemv, kv, mv, path, op)
if err != nil {
return nil, err
}
@@ -443,7 +454,7 @@ func (ns *Namespace) checkWhereMap(seqv, kv, mv reflect.Value, path []string, op
switch elemvv.Kind() {
case reflect.Array, reflect.Slice:
r, err := ns.checkWhereArray(elemvv, kv, mv, path, op)
r, err := ns.checkWhereArray(ctxv, elemvv, kv, mv, path, op)
if err != nil {
return nil, err
}