From 1ba80874e415d1234a64c211aba66bbb756e48b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20Erik=20Pedersen?= Date: Tue, 26 Aug 2025 20:37:08 +0200 Subject: [PATCH] tpl/collections: Add collections.D using Vitter's Method D for sequential random sampling --- common/maps/cache.go | 27 ++++- common/maps/cache_test.go | 57 +++++++++++ go.mod | 4 +- go.sum | 4 +- tpl/collections/collections.go | 29 +++++- tpl/collections/collections_test.go | 29 ++++++ tpl/collections/vitter.go | 149 ++++++++++++++++++++++++++++ tpl/collections/vitter_test.go | 95 ++++++++++++++++++ 8 files changed, 387 insertions(+), 7 deletions(-) create mode 100644 common/maps/cache_test.go create mode 100644 tpl/collections/vitter.go create mode 100644 tpl/collections/vitter_test.go diff --git a/common/maps/cache.go b/common/maps/cache.go index de1535994..cf1307f3c 100644 --- a/common/maps/cache.go +++ b/common/maps/cache.go @@ -20,13 +20,27 @@ import ( // Cache is a simple thread safe cache backed by a map. type Cache[K comparable, T any] struct { m map[K]T + opts CacheOptions hasBeenInitialized bool sync.RWMutex } -// NewCache creates a new Cache. +// CacheOptions are the options for the Cache. +type CacheOptions struct { + // If set, the cache will not grow beyond this size. + Size uint64 +} + +var defaultCacheOptions = CacheOptions{} + +// NewCache creates a new Cache with default options. func NewCache[K comparable, T any]() *Cache[K, T] { - return &Cache[K, T]{m: make(map[K]T)} + return &Cache[K, T]{m: make(map[K]T), opts: defaultCacheOptions} +} + +// NewCacheWithOptions creates a new Cache with the given options. +func NewCacheWithOptions[K comparable, T any](opts CacheOptions) *Cache[K, T] { + return &Cache[K, T]{m: make(map[K]T), opts: opts} } // Delete deletes the given key from the cache. 
@@ -65,6 +79,7 @@ func (c *Cache[K, T]) GetOrCreate(key K, create func() (T, error)) (T, error) { if err != nil { return v, err } + c.clearIfNeeded() c.m[key] = v return v, nil } @@ -127,7 +142,15 @@ func (c *Cache[K, T]) SetIfAbsent(key K, value T) { } } +func (c *Cache[K, T]) clearIfNeeded() { + if c.opts.Size > 0 && uint64(len(c.m)) >= c.opts.Size { + // clear the map + clear(c.m) + } +} + func (c *Cache[K, T]) set(key K, value T) { + c.clearIfNeeded() c.m[key] = value } diff --git a/common/maps/cache_test.go b/common/maps/cache_test.go new file mode 100644 index 000000000..17e38ace8 --- /dev/null +++ b/common/maps/cache_test.go @@ -0,0 +1,57 @@ +// Copyright 2024 The Hugo Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package maps + +import ( + "testing" + + qt "github.com/frankban/quicktest" +) + +func TestCacheSize(t *testing.T) { + c := qt.New(t) + + cache := NewCacheWithOptions[string, string](CacheOptions{Size: 10}) + + for i := 0; i < 30; i++ { + cache.Set(string(rune('a'+i)), "value") + } + + c.Assert(len(cache.m), qt.Equals, 10) + + for i := 20; i < 50; i++ { + cache.GetOrCreate(string(rune('a'+i)), func() (string, error) { + return "value", nil + }) + } + + c.Assert(len(cache.m), qt.Equals, 10) + + for i := 100; i < 200; i++ { + cache.SetIfAbsent(string(rune('a'+i)), "value") + } + + c.Assert(len(cache.m), qt.Equals, 10) + + cache.InitAndGet("foo", func( + get func(key string) (string, bool), set func(key string, value string), + ) error { + for i := 50; i < 100; i++ { + set(string(rune('a'+i)), "value") + } + return nil + }) + + c.Assert(len(cache.m), qt.Equals, 10) +} diff --git a/go.mod b/go.mod index cd33942c5..34888ce67 100644 --- a/go.mod +++ b/go.mod @@ -74,7 +74,7 @@ require ( github.com/yuin/goldmark-emoji v1.0.6 go.uber.org/automaxprocs v1.5.3 gocloud.dev v0.43.0 - golang.org/x/exp v0.0.0-20221031165847-c99f073a8326 + golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b golang.org/x/image v0.30.0 golang.org/x/mod v0.27.0 golang.org/x/net v0.43.0 @@ -190,4 +190,4 @@ require ( software.sslmate.com/src/go-pkcs12 v0.2.0 // indirect ) -go 1.24 +go 1.24.0 diff --git a/go.sum b/go.sum index eb536bf4a..7b676f5c8 100644 --- a/go.sum +++ b/go.sum @@ -580,8 +580,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20221031165847-c99f073a8326 h1:QfTh0HpN6hlw6D3vu8DAwC8pBIwikq0AI1evdm+FksE= 
-golang.org/x/exp v0.0.0-20221031165847-c99f073a8326/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b h1:DXr+pvt3nC887026GRP39Ej11UATqWDmWuS99x26cD0= +golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b/go.mod h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20210220032944-ac19c3e999fb/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= diff --git a/tpl/collections/collections.go b/tpl/collections/collections.go index 0653a453a..5d2fc86e9 100644 --- a/tpl/collections/collections.go +++ b/tpl/collections/collections.go @@ -19,7 +19,7 @@ import ( "context" "errors" "fmt" - "math/rand" + "math/rand/v2" "reflect" "strings" "time" @@ -41,9 +41,12 @@ func New(deps *deps.Deps) *Namespace { } loc := langs.GetLocation(language) + dCache := maps.NewCacheWithOptions[dKey, []int](maps.CacheOptions{Size: 100}) + return &Namespace{ loc: loc, sortComp: compare.New(loc, true), + dCache: dCache, deps: deps, } } @@ -52,6 +55,7 @@ func New(deps *deps.Deps) *Namespace { type Namespace struct { loc *time.Location sortComp *compare.Namespace + dCache *maps.Cache[dKey, []int] deps *deps.Deps } @@ -520,6 +524,29 @@ func (ns *Namespace) Slice(args ...any) any { return collections.Slice(args...) } +type dKey struct { + seed uint64 + n int + hi int +} + +// D returns a slice of n unique random numbers in the range [0, hi) using the provided seed, +// using J. S. Vitter's Method D for sequential random sampling, from Vitter, J.S. +// - An Efficient Algorithm for Sequential Random Sampling - ACM Trans. Math. Software 11 (1985), 37-57. 
+// See https://getkerf.wordpress.com/2016/03/30/the-best-algorithm-no-one-knows-about/ +func (ns *Namespace) D(seed, n, hi any) []int { + key := dKey{seed: cast.ToUint64(seed), n: cast.ToInt(n), hi: cast.ToInt(hi)} + v, _ := ns.dCache.GetOrCreate(key, func() ([]int, error) { + prng := rand.New(rand.NewPCG(key.seed, 0)) + result := make([]int, 0, max(0, min(key.n, key.hi))) // clamp: make panics on a negative cap + _d(prng, key.n, key.hi, func(i int) { + result = append(result, i) + }) + return result, nil + }) + return v +} + type intersector struct { r reflect.Value seen map[any]bool diff --git a/tpl/collections/collections_test.go b/tpl/collections/collections_test.go index fe7f2144d..0e2d99224 100644 --- a/tpl/collections/collections_test.go +++ b/tpl/collections/collections_test.go @@ -788,6 +788,35 @@ func TestUniq(t *testing.T) { } } +func TestD(t *testing.T) { + t.Parallel() + c := qt.New(t) + ns := newNs() + + c.Assert(ns.D(42, 5, 100), qt.DeepEquals, []int{24, 34, 66, 82, 96}) + c.Assert(ns.D(31, 5, 100), qt.DeepEquals, []int{12, 37, 38, 69, 98}) +} + +func BenchmarkD2(b *testing.B) { + ns := newNs() + + runBenchmark := func(seed, n, max int) { + name := fmt.Sprintf("n=%d,max=%d", n, max) + b.Run(name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + ns.D(seed, n, max) + } + }) + } + + runBenchmark(32, 5, 100) + runBenchmark(32, 50, 1000) + runBenchmark(32, 10, 10000) + runBenchmark(32, 500, 10000) + runBenchmark(32, 10, 500000) + runBenchmark(32, 5000, 500000) +} + func (x *TstX) TstRp() string { return "r" + x.A } diff --git a/tpl/collections/vitter.go b/tpl/collections/vitter.go new file mode 100644 index 000000000..937eac0c6 --- /dev/null +++ b/tpl/collections/vitter.go @@ -0,0 +1,149 @@ +// This is just a temporary fork of https://github.com/josharian/vitter (ISC License, https://github.com/josharian/vitter/blob/main/LICENSE) +// +// This file will be removed once https://github.com/josharian/vitter/issues/1 is resolved. 
+ +package collections + +import ( + "math" + "math/rand/v2" +) + +// https://getkerf.wordpress.com/2016/03/30/the-best-algorithm-no-one-knows-about/ + +// Copyright Kevin Lawler, released under ISC License + +// _d generates an in-order uniform random sample of size 'want' from the range [0, max) using the provided PRNG. +// +// Parameters: +// - prng: random number generator +// - want: number of samples to select +// - max: upper bound of the range [0, max) from which to sample +// - choose: callback function invoked with each selected index in ascending order +// +// If want <= 0 or want > max, no samples are selected. +// +// Vitter, J.S. - An Efficient Algorithm for Sequential Random Sampling - ACM Trans. Math. Software 11 (1985), 37-57. +func _d(prng *rand.Rand, want, max int, choose func(n int)) { + if want <= 0 || want > max { + return + } + // POTENTIAL_OPTIMIZATION_POINT: Christian Neukirchen points out we can replace exp(log(x)*y) by pow(x,y) + // POTENTIAL_OPTIMIZATION_POINT: Vitter paper points out an exponentially distributed random var can provide speed ups + // 'a' is space allocated for the hand + // 'n' is the size of the hand + // 'N' is the upper bound on the random card values + j := -1 + qu1 := -want + 1 + max + const negalphainv = -13 // threshold parameter from Vitter's paper for algorithm selection + threshold := -negalphainv * want + + wantf := float64(want) + maxf := float64(max) + ninv := 1.0 / wantf + var nmin1inv float64 + Vprime := math.Exp(math.Log(prng.Float64()) * ninv) + + qu1real := -wantf + 1.0 + maxf + var U, X, y1, y2, top, bottom, negSreal float64 + + for want > 1 && threshold < max { + var S int + + nmin1inv = 1.0 / (-1.0 + wantf) + + for { + for { + X = maxf * (-Vprime + 1.0) + S = int(math.Floor(X)) + + if S < qu1 { + break + } + + Vprime = math.Exp(math.Log(prng.Float64()) * ninv) + } + + U = prng.Float64() + negSreal = float64(-S) + y1 = math.Exp(math.Log(U*maxf/qu1real) * nmin1inv) + Vprime 
= y1 * (-X/maxf + 1.0) * (qu1real / (negSreal + qu1real)) + + if Vprime <= 1.0 { + break + } + + y2 = 1.0 + top = -1.0 + maxf + var limit int + + if -1+want > S { + bottom = -wantf + maxf + limit = -S + max + } else { + bottom = -1.0 + negSreal + maxf + limit = qu1 + } + + for t := max - 1; t >= limit; t-- { + y2 = (y2 * top) / bottom + top-- + bottom-- + } + + if maxf/(-X+maxf) >= y1*math.Exp(math.Log(y2)*nmin1inv) { + Vprime = math.Exp(math.Log(prng.Float64()) * nmin1inv) + break + } + + Vprime = math.Exp(math.Log(prng.Float64()) * ninv) + } + + j += S + 1 + + choose(j) + + max = -S + (-1 + max) + maxf = negSreal + (-1.0 + maxf) + want-- + wantf-- + ninv = nmin1inv + + qu1 = -S + qu1 + qu1real = negSreal + qu1real + + threshold += negalphainv + } + + if want > 1 { + methodA(prng, want, max, j, choose) // if i>0 then n has been decremented + } else { + S := int(math.Floor(float64(max) * Vprime)) + + j += S + 1 + + choose(j) + } +} + +// methodA is the simpler fallback algorithm used when Algorithm D's optimizations are not beneficial. +func methodA(prng *rand.Rand, want, max int, j int, choose func(n int)) { + for want >= 2 { + j++ + V := prng.Float64() + quot := float64(max-want) / float64(max) + for quot > V { + j++ + max-- + quot *= float64(max - want) + quot /= float64(max) + } + choose(j) + max-- + want-- + } + + S := int(math.Floor(float64(max) * prng.Float64())) + j += S + 1 + choose(j) +} diff --git a/tpl/collections/vitter_test.go b/tpl/collections/vitter_test.go new file mode 100644 index 000000000..961d1d367 --- /dev/null +++ b/tpl/collections/vitter_test.go @@ -0,0 +1,95 @@ +// This is just a temporary fork of https://github.com/josharian/vitter (ISC License, https://github.com/josharian/vitter/blob/main/LICENSE) +// +// This file will be removed once https://github.com/josharian/vitter/issues/1 is resolved. 
+ +package collections + +import ( + "fmt" + "math/rand/v2" + "reflect" + "testing" + "time" +) + +var goldenTests = []struct { + seed int64 + k, max int + want []int +}{ + {2, 10, 100, []int{6, 20, 34, 45, 58, 59, 64, 69, 70, 72}}, + {3, 10, 100, []int{8, 11, 22, 26, 30, 40, 74, 76, 93, 95}}, + {4, 5, 1000, []int{183, 283, 443, 501, 531}}, + {5, 15, 100000, []int{12984, 17778, 20370, 23830, 27120, 33258, 45718, 50064, 57096, 58580, 80960, 84396, 84594, 95561, 97687}}, +} + +func TestGolden(t *testing.T) { + for _, test := range goldenTests { + prng := rand.New(rand.NewPCG(uint64(test.seed), 0)) + var got []int + testD(prng, t, test.k, test.max, func(n int) { + got = append(got, n) + }) + if !reflect.DeepEqual(got, test.want) { + t.Errorf("golden(%d, %d, %d) = %#v want %#v", test.seed, test.k, test.max, got, test.want) + } + } +} + +func TestInspectCounts(t *testing.T) { + prng := rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), uint64(time.Now().UnixNano()))) + const max = 100 + const k = 10 + const iters = 10000 + counts := make([]int, max) + for i := 0; i < iters; i++ { + testD(prng, t, k, max, func(n int) { + counts[n]++ + }) + } + for i := range counts { + counts[i] -= (iters * k / max) + } + t.Log(counts) +} + +func testD(prng *rand.Rand, tb testing.TB, want, max int, choose func(n int)) { + prev := -1 + got := want + _d(prng, want, max, func(x int) { + if x <= prev { + tb.Fatalf("backwards: %d then %d", prev, x) + } + if x < 0 || x >= max { + tb.Fatalf("bad selection: %d", x) + } + prev = x + got-- + if got < 0 { + tb.Fatal("choose called too many times") + } + choose(x) + }) + if got != 0 { + tb.Fatal("choose not called enough times") + } +} + +func TestWantIsMax(t *testing.T) { + // Ensure that when want == max, we get all indices. 
+ prng := rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), uint64(time.Now().UnixNano()))) + const n = 10000 + testD(prng, t, n, n, func(n int) {}) +} + +func BenchmarkD(b *testing.B) { + prng := rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), uint64(time.Now().UnixNano()))) + // TODO: count rng calls? + for _, want := range []int{1, 100, 10000} { + b.Run(fmt.Sprintf("want=%d", want), func(b *testing.B) { + for i := 0; i < b.N; i++ { + _d(prng, want, 1000000, func(int) {}) + } + }) + } +}