First commit

This commit is contained in:
2025-09-03 19:24:26 +10:00
parent 2bf7123ae2
commit 1e1402795f
30 changed files with 7582 additions and 33 deletions

View File

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,19 @@
import math
import torch
class UniformSampler:
def __init__(self, timesteps = 1000):
self.timesteps = timesteps
def sample(self, batch_size, device):
return torch.randint(0, self.timesteps, (batch_size,), device=device)
class LogitNormalSampler:
def __init__(self, timesteps = 1000, m = 0, s = 1):
self.timesteps = timesteps
timesteps = torch.linspace(0, 1, timesteps)
logit = torch.log(timesteps / (1 - timesteps))
self.prob = torch.exp(-0.5 * (logit - m) ** 2 / s ** 2) / (s * math.sqrt(2 * math.pi))
def sample(self, batch_size, device):
return torch.multinomial(self.prob, batch_size, replacement=True).to(device)