tpl/collections: Add collections.D using Vitter's Method D for sequential random sampling

This commit is contained in:
Bjørn Erik Pedersen
2025-08-26 20:37:08 +02:00
committed by GitHub
parent 84dd495f2b
commit 1ba80874e4
8 changed files with 387 additions and 7 deletions

View File

@@ -20,13 +20,27 @@ import (
// Cache is a simple thread safe cache backed by a map.
type Cache[K comparable, T any] struct {
m map[K]T
opts CacheOptions
hasBeenInitialized bool
sync.RWMutex
}
// NewCache creates a new Cache.
// CacheOptions are the options for the Cache.
type CacheOptions struct {
// If set, the cache will not grow beyond this size.
Size uint64
}
var defaultCacheOptions = CacheOptions{}
// NewCache creates a new Cache with default options.
func NewCache[K comparable, T any]() *Cache[K, T] {
return &Cache[K, T]{m: make(map[K]T)}
return &Cache[K, T]{m: make(map[K]T), opts: defaultCacheOptions}
}
// NewCacheWithOptions creates a new Cache with the given options.
func NewCacheWithOptions[K comparable, T any](opts CacheOptions) *Cache[K, T] {
return &Cache[K, T]{m: make(map[K]T), opts: opts}
}
// Delete deletes the given key from the cache.
@@ -65,6 +79,7 @@ func (c *Cache[K, T]) GetOrCreate(key K, create func() (T, error)) (T, error) {
if err != nil {
return v, err
}
c.clearIfNeeded()
c.m[key] = v
return v, nil
}
@@ -127,7 +142,15 @@ func (c *Cache[K, T]) SetIfAbsent(key K, value T) {
}
}
func (c *Cache[K, T]) clearIfNeeded() {
if c.opts.Size > 0 && uint64(len(c.m)) >= c.opts.Size {
// clear the map
clear(c.m)
}
}
func (c *Cache[K, T]) set(key K, value T) {
c.clearIfNeeded()
c.m[key] = value
}

57
common/maps/cache_test.go Normal file
View File

@@ -0,0 +1,57 @@
// Copyright 2024 The Hugo Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package maps
import (
"testing"
qt "github.com/frankban/quicktest"
)
func TestCacheSize(t *testing.T) {
c := qt.New(t)
cache := NewCacheWithOptions[string, string](CacheOptions{Size: 10})
for i := 0; i < 30; i++ {
cache.Set(string(rune('a'+i)), "value")
}
c.Assert(len(cache.m), qt.Equals, 10)
for i := 20; i < 50; i++ {
cache.GetOrCreate(string(rune('a'+i)), func() (string, error) {
return "value", nil
})
}
c.Assert(len(cache.m), qt.Equals, 10)
for i := 100; i < 200; i++ {
cache.SetIfAbsent(string(rune('a'+i)), "value")
}
c.Assert(len(cache.m), qt.Equals, 10)
cache.InitAndGet("foo", func(
get func(key string) (string, bool), set func(key string, value string),
) error {
for i := 50; i < 100; i++ {
set(string(rune('a'+i)), "value")
}
return nil
})
c.Assert(len(cache.m), qt.Equals, 10)
}

4
go.mod
View File

@@ -74,7 +74,7 @@ require (
github.com/yuin/goldmark-emoji v1.0.6
go.uber.org/automaxprocs v1.5.3
gocloud.dev v0.43.0
golang.org/x/exp v0.0.0-20221031165847-c99f073a8326
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b
golang.org/x/image v0.30.0
golang.org/x/mod v0.27.0
golang.org/x/net v0.43.0
@@ -190,4 +190,4 @@ require (
software.sslmate.com/src/go-pkcs12 v0.2.0 // indirect
)
go 1.24
go 1.24.0

4
go.sum
View File

@@ -580,8 +580,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
golang.org/x/exp v0.0.0-20221031165847-c99f073a8326 h1:QfTh0HpN6hlw6D3vu8DAwC8pBIwikq0AI1evdm+FksE=
golang.org/x/exp v0.0.0-20221031165847-c99f073a8326/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b h1:DXr+pvt3nC887026GRP39Ej11UATqWDmWuS99x26cD0=
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b/go.mod h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/image v0.0.0-20210220032944-ac19c3e999fb/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=

View File

@@ -19,7 +19,7 @@ import (
"context"
"errors"
"fmt"
"math/rand"
"math/rand/v2"
"reflect"
"strings"
"time"
@@ -41,9 +41,12 @@ func New(deps *deps.Deps) *Namespace {
}
loc := langs.GetLocation(language)
dCache := maps.NewCacheWithOptions[dKey, []int](maps.CacheOptions{Size: 100})
return &Namespace{
loc: loc,
sortComp: compare.New(loc, true),
dCache: dCache,
deps: deps,
}
}
@@ -52,6 +55,7 @@ func New(deps *deps.Deps) *Namespace {
type Namespace struct {
loc *time.Location
sortComp *compare.Namespace
dCache *maps.Cache[dKey, []int]
deps *deps.Deps
}
@@ -520,6 +524,29 @@ func (ns *Namespace) Slice(args ...any) any {
return collections.Slice(args...)
}
type dKey struct {
seed uint64
n int
hi int
}
// D returns a slice of n unique random numbers in the range [0, hi) using the provded seed,
// using J. S. Vitter's Method D for sequential random sampling, from Vitter, J.S.
// - An Efficient Algorithm for Sequential Random Sampling - ACM Trans. Math. Software 11 (1985), 37-57.
// See https://getkerf.wordpress.com/2016/03/30/the-best-algorithm-no-one-knows-about/
func (ns *Namespace) D(seed, n, hi any) []int {
key := dKey{seed: cast.ToUint64(seed), n: cast.ToInt(n), hi: cast.ToInt(hi)}
v, _ := ns.dCache.GetOrCreate(key, func() ([]int, error) {
prng := rand.New(rand.NewPCG(key.seed, 0))
result := make([]int, 0, key.n)
_d(prng, key.n, key.hi, func(i int) {
result = append(result, i)
})
return result, nil
})
return v
}
type intersector struct {
r reflect.Value
seen map[any]bool

View File

@@ -788,6 +788,35 @@ func TestUniq(t *testing.T) {
}
}
func TestD(t *testing.T) {
t.Parallel()
c := qt.New(t)
ns := newNs()
c.Assert(ns.D(42, 5, 100), qt.DeepEquals, []int{24, 34, 66, 82, 96})
c.Assert(ns.D(31, 5, 100), qt.DeepEquals, []int{12, 37, 38, 69, 98})
}
func BenchmarkD2(b *testing.B) {
ns := newNs()
runBenchmark := func(seed, n, max int) {
name := fmt.Sprintf("n=%d,max=%d", n, max)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
ns.D(seed, n, max)
}
})
}
runBenchmark(32, 5, 100)
runBenchmark(32, 50, 1000)
runBenchmark(32, 10, 10000)
runBenchmark(32, 500, 10000)
runBenchmark(32, 10, 500000)
runBenchmark(32, 5000, 500000)
}
func (x *TstX) TstRp() string {
return "r" + x.A
}

149
tpl/collections/vitter.go Normal file
View File

@@ -0,0 +1,149 @@
// This is just a temporary fork of https://github.com/josharian/vitter (ISC License, https://github.com/josharian/vitter/blob/main/LICENSE)
//
// This file will be removed once https://github.com/josharian/vitter/issues/1 is resolved.
package collections
import (
"math"
"math/rand/v2"
)
// https://getkerf.wordpress.com/2016/03/30/the-best-algorithm-no-one-knows-about/
// Copyright Kevin Lawler, released under ISC License
// _d generates an in-order uniform random sample of size 'want' from the range [0, max) using the provided PRNG.
//
// Parameters:
// - prng: random number generator
// - want: number of samples to select
// - max: upper bound of the range [0, max) from which to sample
// - choose: callback function invoked with each selected index in ascending order
//
// If the parameters are invalid (want < 0 or want > max), no samples are selected.
//
// Vitter, J.S. - An Efficient Algorithm for Sequential Random Sampling - ACM Trans. Math. Software 11 (1985), 37-57.
func _d(prng *rand.Rand, want, max int, choose func(n int)) {
if want <= 0 || want > max {
return
}
// POTENTIAL_OPTIMIZATION_POINT: Christian Neukirchen points out we can replace exp(log(x)*y) by pow(x,y)
// POTENTIAL_OPTIMIZATION_POINT: Vitter paper points out an exponentially distributed random var can provide speed ups
// 'a' is space allocated for the hand
// 'n' is the size of the hand
// 'N' is the upper bound on the random card values
j := -1
qu1 := -want + 1 + max
const negalphainv = -13 // threshold parameter from Vitter's paper for algorithm selection
threshold := -negalphainv * want
wantf := float64(want)
maxf := float64(max)
ninv := 1.0 / wantf
var nmin1inv float64
Vprime := math.Exp(math.Log(prng.Float64()) * ninv)
qu1real := -wantf + 1.0 + maxf
var U, X, y1, y2, top, bottom, negSreal float64
for want > 1 && threshold < max {
var S int
nmin1inv = 1.0 / (-1.0 + wantf)
for {
for {
X = maxf * (-Vprime + 1.0)
S = int(math.Floor(X))
if S < qu1 {
break
}
Vprime = math.Exp(math.Log(prng.Float64()) * ninv)
}
U = prng.Float64()
negSreal = float64(-S)
y1 = math.Exp(math.Log(U*maxf/qu1real) * nmin1inv)
Vprime = y1 * (-X/maxf + 1.0) * (qu1real / (negSreal + qu1real))
if Vprime <= 1.0 {
break
}
y2 = 1.0
top = -1.0 + maxf
var limit int
if -1+want > S {
bottom = -wantf + maxf
limit = -S + max
} else {
bottom = -1.0 + negSreal + maxf
limit = qu1
}
for t := max - 1; t >= limit; t-- {
y2 = (y2 * top) / bottom
top--
bottom--
}
if maxf/(-X+maxf) >= y1*math.Exp(math.Log(y2)*nmin1inv) {
Vprime = math.Exp(math.Log(prng.Float64()) * nmin1inv)
break
}
Vprime = math.Exp(math.Log(prng.Float64()) * ninv)
}
j += S + 1
choose(j)
max = -S + (-1 + max)
maxf = negSreal + (-1.0 + maxf)
want--
wantf--
ninv = nmin1inv
qu1 = -S + qu1
qu1real = negSreal + qu1real
threshold += negalphainv
}
if want > 1 {
methodA(prng, want, max, j, choose) // if i>0 then n has been decremented
} else {
S := int(math.Floor(float64(max) * Vprime))
j += S + 1
choose(j)
}
}
// methodA is the simpler fallback algorithm used when Algorithm D's optimizations are not beneficial.
func methodA(prng *rand.Rand, want, max int, j int, choose func(n int)) {
for want >= 2 {
j++
V := prng.Float64()
quot := float64(max-want) / float64(max)
for quot > V {
j++
max--
quot *= float64(max - want)
quot /= float64(max)
}
choose(j)
max--
want--
}
S := int(math.Floor(float64(max) * prng.Float64()))
j += S + 1
choose(j)
}

View File

@@ -0,0 +1,95 @@
// This is just a temporary fork of https://github.com/josharian/vitter (ISC License, https://github.com/josharian/vitter/blob/main/LICENSE)
//
// This file will be removed once https://github.com/josharian/vitter/issues/1 is resolved.
package collections
import (
"fmt"
"math/rand/v2"
"reflect"
"testing"
"time"
)
var goldenTests = []struct {
seed int64
k, max int
want []int
}{
{2, 10, 100, []int{6, 20, 34, 45, 58, 59, 64, 69, 70, 72}},
{3, 10, 100, []int{8, 11, 22, 26, 30, 40, 74, 76, 93, 95}},
{4, 5, 1000, []int{183, 283, 443, 501, 531}},
{5, 15, 100000, []int{12984, 17778, 20370, 23830, 27120, 33258, 45718, 50064, 57096, 58580, 80960, 84396, 84594, 95561, 97687}},
}
func TestGolden(t *testing.T) {
for _, test := range goldenTests {
prng := rand.New(rand.NewPCG(uint64(test.seed), 0))
var got []int
testD(prng, t, test.k, test.max, func(n int) {
got = append(got, n)
})
if !reflect.DeepEqual(got, test.want) {
t.Errorf("golden(%d, %d, %d) = %#v want %#v", test.seed, test.k, test.max, got, test.want)
}
}
}
func TestInspectCounts(t *testing.T) {
prng := rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), uint64(time.Now().UnixNano())))
const max = 100
const k = 10
const iters = 10000
counts := make([]int, max)
for i := 0; i < iters; i++ {
testD(prng, t, k, max, func(n int) {
counts[n]++
})
}
for i := range counts {
counts[i] -= (iters * k / max)
}
t.Log(counts)
}
func testD(prng *rand.Rand, tb testing.TB, want, max int, choose func(n int)) {
prev := -1
got := want
_d(prng, want, max, func(x int) {
if x <= prev {
tb.Fatalf("backwards: %d then %d", prev, x)
}
if x < 0 || x >= max {
tb.Fatalf("bad selection: %d", x)
}
prev = x
got--
if got < 0 {
tb.Fatal("choose called too many times")
}
choose(x)
})
if got != 0 {
tb.Fatal("choose not called enough times")
}
}
func TestWantIsMax(t *testing.T) {
// Ensure that when want == max, we get all indices.
prng := rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), uint64(time.Now().UnixNano())))
const n = 10000
testD(prng, t, n, n, func(n int) {})
}
func BenchmarkD(b *testing.B) {
prng := rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), uint64(time.Now().UnixNano())))
// TODO: count rng calls?
for _, want := range []int{1, 100, 10000} {
b.Run(fmt.Sprintf("want=%d", want), func(b *testing.B) {
for i := 0; i < b.N; i++ {
_d(prng, want, 1000000, func(int) {})
}
})
}
}