mirror of
https://github.com/gohugoio/hugo.git
synced 2025-08-29 22:29:56 +02:00
tpl/collections: Add collections.D using Vitter's Method D for sequential random sampling
This commit is contained in:
committed by
GitHub
parent
84dd495f2b
commit
1ba80874e4
@@ -20,13 +20,27 @@ import (
|
||||
// Cache is a simple thread safe cache backed by a map.
|
||||
type Cache[K comparable, T any] struct {
|
||||
m map[K]T
|
||||
opts CacheOptions
|
||||
hasBeenInitialized bool
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// NewCache creates a new Cache.
|
||||
// CacheOptions are the options for the Cache.
|
||||
type CacheOptions struct {
|
||||
// If set (> 0), the cache is cleared completely once it holds this many entries, so it never grows beyond this size.
|
||||
Size uint64
|
||||
}
|
||||
|
||||
var defaultCacheOptions = CacheOptions{}
|
||||
|
||||
// NewCache creates a new Cache with default options.
|
||||
func NewCache[K comparable, T any]() *Cache[K, T] {
|
||||
return &Cache[K, T]{m: make(map[K]T)}
|
||||
return &Cache[K, T]{m: make(map[K]T), opts: defaultCacheOptions}
|
||||
}
|
||||
|
||||
// NewCacheWithOptions creates a new Cache with the given options.
|
||||
func NewCacheWithOptions[K comparable, T any](opts CacheOptions) *Cache[K, T] {
|
||||
return &Cache[K, T]{m: make(map[K]T), opts: opts}
|
||||
}
|
||||
|
||||
// Delete deletes the given key from the cache.
|
||||
@@ -65,6 +79,7 @@ func (c *Cache[K, T]) GetOrCreate(key K, create func() (T, error)) (T, error) {
|
||||
if err != nil {
|
||||
return v, err
|
||||
}
|
||||
c.clearIfNeeded()
|
||||
c.m[key] = v
|
||||
return v, nil
|
||||
}
|
||||
@@ -127,7 +142,15 @@ func (c *Cache[K, T]) SetIfAbsent(key K, value T) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache[K, T]) clearIfNeeded() {
|
||||
if c.opts.Size > 0 && uint64(len(c.m)) >= c.opts.Size {
|
||||
// clear the map
|
||||
clear(c.m)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache[K, T]) set(key K, value T) {
|
||||
c.clearIfNeeded()
|
||||
c.m[key] = value
|
||||
}
|
||||
|
||||
|
57
common/maps/cache_test.go
Normal file
57
common/maps/cache_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
// Copyright 2024 The Hugo Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package maps
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
qt "github.com/frankban/quicktest"
|
||||
)
|
||||
|
||||
func TestCacheSize(t *testing.T) {
|
||||
c := qt.New(t)
|
||||
|
||||
cache := NewCacheWithOptions[string, string](CacheOptions{Size: 10})
|
||||
|
||||
for i := 0; i < 30; i++ {
|
||||
cache.Set(string(rune('a'+i)), "value")
|
||||
}
|
||||
|
||||
c.Assert(len(cache.m), qt.Equals, 10)
|
||||
|
||||
for i := 20; i < 50; i++ {
|
||||
cache.GetOrCreate(string(rune('a'+i)), func() (string, error) {
|
||||
return "value", nil
|
||||
})
|
||||
}
|
||||
|
||||
c.Assert(len(cache.m), qt.Equals, 10)
|
||||
|
||||
for i := 100; i < 200; i++ {
|
||||
cache.SetIfAbsent(string(rune('a'+i)), "value")
|
||||
}
|
||||
|
||||
c.Assert(len(cache.m), qt.Equals, 10)
|
||||
|
||||
cache.InitAndGet("foo", func(
|
||||
get func(key string) (string, bool), set func(key string, value string),
|
||||
) error {
|
||||
for i := 50; i < 100; i++ {
|
||||
set(string(rune('a'+i)), "value")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
c.Assert(len(cache.m), qt.Equals, 10)
|
||||
}
|
4
go.mod
4
go.mod
@@ -74,7 +74,7 @@ require (
|
||||
github.com/yuin/goldmark-emoji v1.0.6
|
||||
go.uber.org/automaxprocs v1.5.3
|
||||
gocloud.dev v0.43.0
|
||||
golang.org/x/exp v0.0.0-20221031165847-c99f073a8326
|
||||
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b
|
||||
golang.org/x/image v0.30.0
|
||||
golang.org/x/mod v0.27.0
|
||||
golang.org/x/net v0.43.0
|
||||
@@ -190,4 +190,4 @@ require (
|
||||
software.sslmate.com/src/go-pkcs12 v0.2.0 // indirect
|
||||
)
|
||||
|
||||
go 1.24
|
||||
go 1.24.0
|
||||
|
4
go.sum
4
go.sum
@@ -580,8 +580,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
|
||||
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
|
||||
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
|
||||
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
|
||||
golang.org/x/exp v0.0.0-20221031165847-c99f073a8326 h1:QfTh0HpN6hlw6D3vu8DAwC8pBIwikq0AI1evdm+FksE=
|
||||
golang.org/x/exp v0.0.0-20221031165847-c99f073a8326/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
|
||||
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b h1:DXr+pvt3nC887026GRP39Ej11UATqWDmWuS99x26cD0=
|
||||
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b/go.mod h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4=
|
||||
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
golang.org/x/image v0.0.0-20210220032944-ac19c3e999fb/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
|
@@ -19,7 +19,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"math/rand/v2"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -41,9 +41,12 @@ func New(deps *deps.Deps) *Namespace {
|
||||
}
|
||||
loc := langs.GetLocation(language)
|
||||
|
||||
dCache := maps.NewCacheWithOptions[dKey, []int](maps.CacheOptions{Size: 100})
|
||||
|
||||
return &Namespace{
|
||||
loc: loc,
|
||||
sortComp: compare.New(loc, true),
|
||||
dCache: dCache,
|
||||
deps: deps,
|
||||
}
|
||||
}
|
||||
@@ -52,6 +55,7 @@ func New(deps *deps.Deps) *Namespace {
|
||||
type Namespace struct {
|
||||
loc *time.Location
|
||||
sortComp *compare.Namespace
|
||||
dCache *maps.Cache[dKey, []int]
|
||||
deps *deps.Deps
|
||||
}
|
||||
|
||||
@@ -520,6 +524,29 @@ func (ns *Namespace) Slice(args ...any) any {
|
||||
return collections.Slice(args...)
|
||||
}
|
||||
|
||||
type dKey struct {
|
||||
seed uint64
|
||||
n int
|
||||
hi int
|
||||
}
|
||||
|
||||
// D returns a slice of n unique random numbers in the range [0, hi) using the provded seed,
|
||||
// using J. S. Vitter's Method D for sequential random sampling, from Vitter, J.S.
|
||||
// - An Efficient Algorithm for Sequential Random Sampling - ACM Trans. Math. Software 11 (1985), 37-57.
|
||||
// See https://getkerf.wordpress.com/2016/03/30/the-best-algorithm-no-one-knows-about/
|
||||
func (ns *Namespace) D(seed, n, hi any) []int {
|
||||
key := dKey{seed: cast.ToUint64(seed), n: cast.ToInt(n), hi: cast.ToInt(hi)}
|
||||
v, _ := ns.dCache.GetOrCreate(key, func() ([]int, error) {
|
||||
prng := rand.New(rand.NewPCG(key.seed, 0))
|
||||
result := make([]int, 0, key.n)
|
||||
_d(prng, key.n, key.hi, func(i int) {
|
||||
result = append(result, i)
|
||||
})
|
||||
return result, nil
|
||||
})
|
||||
return v
|
||||
}
|
||||
|
||||
type intersector struct {
|
||||
r reflect.Value
|
||||
seen map[any]bool
|
||||
|
@@ -788,6 +788,35 @@ func TestUniq(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestD(t *testing.T) {
|
||||
t.Parallel()
|
||||
c := qt.New(t)
|
||||
ns := newNs()
|
||||
|
||||
c.Assert(ns.D(42, 5, 100), qt.DeepEquals, []int{24, 34, 66, 82, 96})
|
||||
c.Assert(ns.D(31, 5, 100), qt.DeepEquals, []int{12, 37, 38, 69, 98})
|
||||
}
|
||||
|
||||
func BenchmarkD2(b *testing.B) {
|
||||
ns := newNs()
|
||||
|
||||
runBenchmark := func(seed, n, max int) {
|
||||
name := fmt.Sprintf("n=%d,max=%d", n, max)
|
||||
b.Run(name, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
ns.D(seed, n, max)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
runBenchmark(32, 5, 100)
|
||||
runBenchmark(32, 50, 1000)
|
||||
runBenchmark(32, 10, 10000)
|
||||
runBenchmark(32, 500, 10000)
|
||||
runBenchmark(32, 10, 500000)
|
||||
runBenchmark(32, 5000, 500000)
|
||||
}
|
||||
|
||||
func (x *TstX) TstRp() string {
|
||||
return "r" + x.A
|
||||
}
|
||||
|
149
tpl/collections/vitter.go
Normal file
149
tpl/collections/vitter.go
Normal file
@@ -0,0 +1,149 @@
|
||||
// This is just a temporary fork of https://github.com/josharian/vitter (ISC License, https://github.com/josharian/vitter/blob/main/LICENSE)
|
||||
//
|
||||
// This file will be removed once https://github.com/josharian/vitter/issues/1 is resolved.
|
||||
|
||||
package collections
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
)
|
||||
|
||||
// https://getkerf.wordpress.com/2016/03/30/the-best-algorithm-no-one-knows-about/
|
||||
|
||||
// Copyright Kevin Lawler, released under ISC License
|
||||
|
||||
// _d generates an in-order uniform random sample of size 'want' from the range [0, max) using the provided PRNG.
|
||||
//
|
||||
// Parameters:
|
||||
// - prng: random number generator
|
||||
// - want: number of samples to select
|
||||
// - max: upper bound of the range [0, max) from which to sample
|
||||
// - choose: callback function invoked with each selected index in ascending order
|
||||
//
|
||||
// If want <= 0 or want > max, no samples are selected and choose is never called.
|
||||
//
|
||||
// Vitter, J.S. - An Efficient Algorithm for Sequential Random Sampling - ACM Trans. Math. Software 11 (1985), 37-57.
|
||||
func _d(prng *rand.Rand, want, max int, choose func(n int)) {
|
||||
if want <= 0 || want > max {
|
||||
return
|
||||
}
|
||||
// POTENTIAL_OPTIMIZATION_POINT: Christian Neukirchen points out we can replace exp(log(x)*y) by pow(x,y)
|
||||
// POTENTIAL_OPTIMIZATION_POINT: Vitter paper points out an exponentially distributed random var can provide speed ups
|
||||
// 'a' is space allocated for the hand (not present in this version: each pick is passed to choose instead)
|
||||
// 'n' (here: want) is the size of the hand
|
||||
// 'N' (here: max) is the upper bound on the random card values
|
||||
j := -1
|
||||
qu1 := -want + 1 + max
|
||||
const negalphainv = -13 // threshold parameter from Vitter's paper for algorithm selection
|
||||
threshold := -negalphainv * want
|
||||
|
||||
wantf := float64(want)
|
||||
maxf := float64(max)
|
||||
ninv := 1.0 / wantf
|
||||
var nmin1inv float64
|
||||
Vprime := math.Exp(math.Log(prng.Float64()) * ninv)
|
||||
|
||||
qu1real := -wantf + 1.0 + maxf
|
||||
var U, X, y1, y2, top, bottom, negSreal float64
|
||||
|
||||
for want > 1 && threshold < max {
|
||||
var S int
|
||||
|
||||
nmin1inv = 1.0 / (-1.0 + wantf)
|
||||
|
||||
for {
|
||||
for {
|
||||
X = maxf * (-Vprime + 1.0)
|
||||
S = int(math.Floor(X))
|
||||
|
||||
if S < qu1 {
|
||||
break
|
||||
}
|
||||
|
||||
Vprime = math.Exp(math.Log(prng.Float64()) * ninv)
|
||||
}
|
||||
|
||||
U = prng.Float64()
|
||||
negSreal = float64(-S)
|
||||
y1 = math.Exp(math.Log(U*maxf/qu1real) * nmin1inv)
|
||||
Vprime = y1 * (-X/maxf + 1.0) * (qu1real / (negSreal + qu1real))
|
||||
|
||||
if Vprime <= 1.0 {
|
||||
break
|
||||
}
|
||||
|
||||
y2 = 1.0
|
||||
top = -1.0 + maxf
|
||||
var limit int
|
||||
|
||||
if -1+want > S {
|
||||
bottom = -wantf + maxf
|
||||
limit = -S + max
|
||||
} else {
|
||||
bottom = -1.0 + negSreal + maxf
|
||||
limit = qu1
|
||||
}
|
||||
|
||||
for t := max - 1; t >= limit; t-- {
|
||||
y2 = (y2 * top) / bottom
|
||||
top--
|
||||
bottom--
|
||||
}
|
||||
|
||||
if maxf/(-X+maxf) >= y1*math.Exp(math.Log(y2)*nmin1inv) {
|
||||
Vprime = math.Exp(math.Log(prng.Float64()) * nmin1inv)
|
||||
break
|
||||
}
|
||||
|
||||
Vprime = math.Exp(math.Log(prng.Float64()) * ninv)
|
||||
}
|
||||
|
||||
j += S + 1
|
||||
|
||||
choose(j)
|
||||
|
||||
max = -S + (-1 + max)
|
||||
maxf = negSreal + (-1.0 + maxf)
|
||||
want--
|
||||
wantf--
|
||||
ninv = nmin1inv
|
||||
|
||||
qu1 = -S + qu1
|
||||
qu1real = negSreal + qu1real
|
||||
|
||||
threshold += negalphainv
|
||||
}
|
||||
|
||||
if want > 1 {
|
||||
methodA(prng, want, max, j, choose) // if i>0 then n has been decremented
|
||||
} else {
|
||||
S := int(math.Floor(float64(max) * Vprime))
|
||||
|
||||
j += S + 1
|
||||
|
||||
choose(j)
|
||||
}
|
||||
}
|
||||
|
||||
// methodA is the simpler fallback algorithm used when Algorithm D's optimizations are not beneficial.
|
||||
func methodA(prng *rand.Rand, want, max int, j int, choose func(n int)) {
|
||||
for want >= 2 {
|
||||
j++
|
||||
V := prng.Float64()
|
||||
quot := float64(max-want) / float64(max)
|
||||
for quot > V {
|
||||
j++
|
||||
max--
|
||||
quot *= float64(max - want)
|
||||
quot /= float64(max)
|
||||
}
|
||||
choose(j)
|
||||
max--
|
||||
want--
|
||||
}
|
||||
|
||||
S := int(math.Floor(float64(max) * prng.Float64()))
|
||||
j += S + 1
|
||||
choose(j)
|
||||
}
|
95
tpl/collections/vitter_test.go
Normal file
95
tpl/collections/vitter_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
// This is just a temporary fork of https://github.com/josharian/vitter (ISC License, https://github.com/josharian/vitter/blob/main/LICENSE)
|
||||
//
|
||||
// This file will be removed once https://github.com/josharian/vitter/issues/1 is resolved.
|
||||
|
||||
package collections
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var goldenTests = []struct {
|
||||
seed int64
|
||||
k, max int
|
||||
want []int
|
||||
}{
|
||||
{2, 10, 100, []int{6, 20, 34, 45, 58, 59, 64, 69, 70, 72}},
|
||||
{3, 10, 100, []int{8, 11, 22, 26, 30, 40, 74, 76, 93, 95}},
|
||||
{4, 5, 1000, []int{183, 283, 443, 501, 531}},
|
||||
{5, 15, 100000, []int{12984, 17778, 20370, 23830, 27120, 33258, 45718, 50064, 57096, 58580, 80960, 84396, 84594, 95561, 97687}},
|
||||
}
|
||||
|
||||
func TestGolden(t *testing.T) {
|
||||
for _, test := range goldenTests {
|
||||
prng := rand.New(rand.NewPCG(uint64(test.seed), 0))
|
||||
var got []int
|
||||
testD(prng, t, test.k, test.max, func(n int) {
|
||||
got = append(got, n)
|
||||
})
|
||||
if !reflect.DeepEqual(got, test.want) {
|
||||
t.Errorf("golden(%d, %d, %d) = %#v want %#v", test.seed, test.k, test.max, got, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInspectCounts(t *testing.T) {
|
||||
prng := rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), uint64(time.Now().UnixNano())))
|
||||
const max = 100
|
||||
const k = 10
|
||||
const iters = 10000
|
||||
counts := make([]int, max)
|
||||
for i := 0; i < iters; i++ {
|
||||
testD(prng, t, k, max, func(n int) {
|
||||
counts[n]++
|
||||
})
|
||||
}
|
||||
for i := range counts {
|
||||
counts[i] -= (iters * k / max)
|
||||
}
|
||||
t.Log(counts)
|
||||
}
|
||||
|
||||
func testD(prng *rand.Rand, tb testing.TB, want, max int, choose func(n int)) {
|
||||
prev := -1
|
||||
got := want
|
||||
_d(prng, want, max, func(x int) {
|
||||
if x <= prev {
|
||||
tb.Fatalf("backwards: %d then %d", prev, x)
|
||||
}
|
||||
if x < 0 || x >= max {
|
||||
tb.Fatalf("bad selection: %d", x)
|
||||
}
|
||||
prev = x
|
||||
got--
|
||||
if got < 0 {
|
||||
tb.Fatal("choose called too many times")
|
||||
}
|
||||
choose(x)
|
||||
})
|
||||
if got != 0 {
|
||||
tb.Fatal("choose not called enough times")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWantIsMax(t *testing.T) {
|
||||
// Ensure that when want == max, we get all indices.
|
||||
prng := rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), uint64(time.Now().UnixNano())))
|
||||
const n = 10000
|
||||
testD(prng, t, n, n, func(n int) {})
|
||||
}
|
||||
|
||||
func BenchmarkD(b *testing.B) {
|
||||
prng := rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), uint64(time.Now().UnixNano())))
|
||||
// TODO: count rng calls?
|
||||
for _, want := range []int{1, 100, 10000} {
|
||||
b.Run(fmt.Sprintf("want=%d", want), func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_d(prng, want, 1000000, func(int) {})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user