diff --git a/hugolib/template.go b/hugolib/template.go index 664bb5214..0d8c29bf4 100644 --- a/hugolib/template.go +++ b/hugolib/template.go @@ -14,6 +14,7 @@ import ( "strings" "github.com/eknkc/amber" + "github.com/spf13/cast" "github.com/spf13/hugo/helpers" jww "github.com/spf13/jwalterweatherman" ) @@ -167,8 +168,15 @@ func In(l interface{}, v interface{}) bool { // First is exposed to templates, to iterate over the first N items in a // rangeable list. -func First(limit int, seq interface{}) (interface{}, error) { - if limit < 1 { +func First(limit interface{}, seq interface{}) (interface{}, error) { + + limitv, err := cast.ToIntE(limit) + + if err != nil { + return nil, err + } + + if limitv < 1 { return nil, errors.New("can't return negative/empty count of items from sequence") } @@ -189,10 +197,10 @@ func First(limit int, seq interface{}) (interface{}, error) { default: return nil, errors.New("can't iterate over " + reflect.ValueOf(seq).Type().String()) } - if limit > seqv.Len() { - limit = seqv.Len() + if limitv > seqv.Len() { + limitv = seqv.Len() } - return seqv.Slice(0, limit).Interface(), nil + return seqv.Slice(0, limitv).Interface(), nil } func Where(seq, key, match interface{}) (interface{}, error) { diff --git a/hugolib/template_test.go b/hugolib/template_test.go index eb0a42707..b4d95f0b4 100644 --- a/hugolib/template_test.go +++ b/hugolib/template_test.go @@ -125,21 +125,31 @@ func TestDoArithmetic(t *testing.T) { func TestFirst(t *testing.T) { for i, this := range []struct { - count int + count interface{} sequence interface{} expect interface{} }{ - {2, []string{"a", "b", "c"}, []string{"a", "b"}}, - {3, []string{"a", "b"}, []string{"a", "b"}}, - {2, []int{100, 200, 300}, []int{100, 200}}, + {int(2), []string{"a", "b", "c"}, []string{"a", "b"}}, + {int32(3), []string{"a", "b"}, []string{"a", "b"}}, + {int64(2), []int{100, 200, 300}, []int{100, 200}}, + {100, []int{100, 200}, []int{100, 200}}, + {"1", []int{100, 200, 300}, []int{100}}, + {int64(-1), []int{100, 200, 300}, false}, + {"noint", []int{100, 200, 300}, false}, } { results, err := First(this.count, this.sequence) - if err != nil { - t.Errorf("[%d] failed: %s", i, err) - continue - } - if !reflect.DeepEqual(results, this.expect) { - t.Errorf("[%d] First %d items, got %v but expected %v", i, this.count, results, this.expect) + if b, ok := this.expect.(bool); ok && !b { + if err == nil { + t.Errorf("[%d] First didn't return an expected error") + } + } else { + if err != nil { + t.Errorf("[%d] failed: %s", i, err) + continue + } + if !reflect.DeepEqual(results, this.expect) { + t.Errorf("[%d] First %d items, got %v but expected %v", i, this.count, results, this.expect) + } } } }