Make thread_local replacement for MinGW slightly nicer

This commit is contained in:
Tamás Bálint Misius
2021-10-23 09:38:58 +02:00
parent 0ed8d0a0be
commit 0f2eedd4fb
2 changed files with 73 additions and 68 deletions

View File

@@ -5,58 +5,41 @@
# include <cstdlib> # include <cstdlib>
# include <cassert> # include <cassert>
static pthread_once_t once = PTHREAD_ONCE_INIT; void *ThreadLocalCommon::Get() const
static pthread_key_t key;
struct ThreadLocalCommon
{ {
size_t size; // https://stackoverflow.com/questions/16552710/how-do-you-get-the-start-and-end-addresses-of-a-custom-elf-section
void (*ctor)(void *); extern ThreadLocalCommon __start_tpt_tls;
void (*dtor)(void *); extern ThreadLocalCommon __stop_tpt_tls;
size_t padding; static pthread_once_t once = PTHREAD_ONCE_INIT;
}; static pthread_key_t key;
static_assert(sizeof(ThreadLocalCommon) == 0x20, "fix me");
struct ThreadLocalEntry struct ThreadLocalEntry
{
void *ptr;
};
// https://stackoverflow.com/questions/16552710/how-do-you-get-the-start-and-end-addresses-of-a-custom-elf-section
extern ThreadLocalCommon __start_tpt_tls;
extern ThreadLocalCommon __stop_tpt_tls;
static void ThreadLocalDestroy(void *opaque)
{
auto *staticsBegin = &__start_tpt_tls;
auto *staticsEnd = &__stop_tpt_tls;
auto staticsCount = staticsEnd - staticsBegin;
auto *liveObjects = reinterpret_cast<ThreadLocalEntry *>(opaque);
if (liveObjects)
{ {
for (auto i = 0; i < staticsCount; ++i) void *ptr;
{ };
if (liveObjects[i].ptr)
{
staticsBegin[i].dtor(liveObjects[i].ptr);
free(liveObjects[i].ptr);
}
}
free(liveObjects);
}
}
static void ThreadLocalCreate()
{
assert(!pthread_key_create(&key, ThreadLocalDestroy));
}
void *ThreadLocalGet(void *opaque)
{
auto *staticsBegin = &__start_tpt_tls; auto *staticsBegin = &__start_tpt_tls;
auto *staticsEnd = &__stop_tpt_tls; auto *staticsEnd = &__stop_tpt_tls;
auto *staticsOpaque = reinterpret_cast<ThreadLocalCommon *>(opaque); pthread_once(&once, []() -> void {
pthread_once(&once, ThreadLocalCreate); assert(!pthread_key_create(&key, [](void *opaque) -> void {
auto *staticsBegin = &__start_tpt_tls;
auto *staticsEnd = &__stop_tpt_tls;
auto staticsCount = staticsEnd - staticsBegin;
auto *liveObjects = reinterpret_cast<ThreadLocalEntry *>(opaque);
if (liveObjects)
{
for (auto i = 0; i < staticsCount; ++i)
{
if (liveObjects[i].ptr)
{
staticsBegin[i].dtor(liveObjects[i].ptr);
free(liveObjects[i].ptr);
}
}
free(liveObjects);
}
}));
});
auto *liveObjects = reinterpret_cast<ThreadLocalEntry *>(pthread_getspecific(key)); auto *liveObjects = reinterpret_cast<ThreadLocalEntry *>(pthread_getspecific(key));
if (!liveObjects) if (!liveObjects)
{ {
@@ -65,7 +48,7 @@ void *ThreadLocalGet(void *opaque)
assert(liveObjects); assert(liveObjects);
assert(!pthread_setspecific(key, reinterpret_cast<void *>(liveObjects))); assert(!pthread_setspecific(key, reinterpret_cast<void *>(liveObjects)));
} }
auto idx = staticsOpaque - staticsBegin; auto idx = this - staticsBegin;
auto &entry = liveObjects[idx]; auto &entry = liveObjects[idx];
if (!entry.ptr) if (!entry.ptr)
{ {
@@ -75,5 +58,4 @@ void *ThreadLocalGet(void *opaque)
} }
return entry.ptr; return entry.ptr;
} }
#endif #endif

View File

@@ -4,39 +4,62 @@
#ifdef __MINGW32__ #ifdef __MINGW32__
# include <cstddef> # include <cstddef>
template<class Type> class ThreadLocalCommon
class ThreadLocal
{ {
static void Ctor(Type *type) ThreadLocalCommon(const ThreadLocalCommon &other) = delete;
{ ThreadLocalCommon &operator =(const ThreadLocalCommon &other) = delete;
new (type) Type();
}
static void Dtor(Type *type) protected:
{ size_t size;
type->~Type(); void (*ctor)(void *);
} void (*dtor)(void *);
size_t size = sizeof(Type);
void (*ctor)(Type *) = Ctor;
void (*dtor)(Type *) = Dtor;
size_t padding; size_t padding;
void *Get() const;
public: public:
Type *operator &() ThreadLocalCommon() = default;
static constexpr size_t Alignment = 0x20;
};
// * If this fails, add or remove padding fields, possibly change Alignment to a larger power of 2.
static_assert(sizeof(ThreadLocalCommon) == ThreadLocalCommon::Alignment, "fix me");
template<class Type>
class ThreadLocal : public ThreadLocalCommon
{
static void Ctor(void *type)
{ {
static_assert(sizeof(ThreadLocal<Type>) == 0x20, "fix me"); new(type) Type();
void *ThreadLocalGet(void *opaque);
return reinterpret_cast<Type *>(ThreadLocalGet(reinterpret_cast<void *>(this)));
} }
operator Type &() static void Dtor(void *type)
{
reinterpret_cast<Type *>(type)->~Type();
}
public:
ThreadLocal()
{
// * If this fails, you're out of luck.
static_assert(sizeof(ThreadLocal<Type>) == sizeof(ThreadLocalCommon), "fix me");
size = sizeof(Type);
ctor = Ctor;
dtor = Dtor;
}
Type *operator &() const
{
return reinterpret_cast<Type *>(Get());
}
operator Type &() const
{ {
return *(this->operator &()); return *(this->operator &());
} }
}; };
# define THREAD_LOCAL(Type, tl) ThreadLocal<Type> tl __attribute__((section("tpt_tls"))) __attribute__((aligned(0x20))) # define THREAD_LOCAL(Type, tl) const ThreadLocal<Type> tl __attribute__((section("tpt_tls"), aligned(ThreadLocalCommon::Alignment)))
#else #else
# define THREAD_LOCAL(Type, tl) thread_local Type tl # define THREAD_LOCAL(Type, tl) thread_local Type tl
#endif #endif