First commit
This commit is contained in:
0
vibevoice/schedule/__init__.py
Normal file
0
vibevoice/schedule/__init__.py
Normal file
1065
vibevoice/schedule/dpm_solver.py
Normal file
1065
vibevoice/schedule/dpm_solver.py
Normal file
File diff suppressed because it is too large
Load Diff
19
vibevoice/schedule/timestep_sampler.py
Normal file
19
vibevoice/schedule/timestep_sampler.py
Normal file
@ -0,0 +1,19 @@
|
||||
import math
|
||||
import torch
|
||||
|
||||
|
||||
class UniformSampler:
|
||||
def __init__(self, timesteps = 1000):
|
||||
self.timesteps = timesteps
|
||||
def sample(self, batch_size, device):
|
||||
return torch.randint(0, self.timesteps, (batch_size,), device=device)
|
||||
|
||||
class LogitNormalSampler:
|
||||
def __init__(self, timesteps = 1000, m = 0, s = 1):
|
||||
self.timesteps = timesteps
|
||||
timesteps = torch.linspace(0, 1, timesteps)
|
||||
logit = torch.log(timesteps / (1 - timesteps))
|
||||
self.prob = torch.exp(-0.5 * (logit - m) ** 2 / s ** 2) / (s * math.sqrt(2 * math.pi))
|
||||
def sample(self, batch_size, device):
|
||||
return torch.multinomial(self.prob, batch_size, replacement=True).to(device)
|
||||
|
||||
Reference in New Issue
Block a user