import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable


def onehot_from_logits(logits, eps=0.0):
    """
    Given batch of logits, return one-hot sample using epsilon greedy strategy
    (based on given epsilon)
    """
    # get best (according to current policy) actions in one-hot form
    argmax_acs = (logits == logits.max(1, keepdim=True)[0]).float()
    if eps == 0.0:
        return argmax_acs
    # get random actions in one-hot form
    rand_acs = Variable(
        torch.eye(logits.shape[1])[
            [np.random.choice(range(logits.shape[1]), size=logits.shape[0])]
        ],
        requires_grad=False,
    )
    # choose between best and random actions using epsilon greedy
    return torch.stack(
        [
            argmax_acs[i] if r > eps else rand_acs[i]
            for i, r in enumerate(torch.rand(logits.shape[0]))
        ]
    )
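# Illustrative usage sketch (assumed, not part of the original module; the
# helper name and tensor shapes below are made up for demonstration):
def _example_onehot_from_logits():
    logits = torch.randn(4, 3)                          # batch of 4, 3 actions
    greedy_acs = onehot_from_logits(logits)             # pure argmax one-hots
    explore_acs = onehot_from_logits(logits, eps=0.1)   # 10% random actions
    return greedy_acs, explore_acs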
def sample_gumbel(shape, device, eps=1e-20, tens_type=torch.FloatTensor):
    """Sample from Gumbel(0, 1)"""
    U = Variable(tens_type(*shape).uniform_(), requires_grad=False).to(device)
    return -torch.log(-torch.log(U + eps) + eps)
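# Illustrative sketch (assumed, not part of the original module): adding
# Gumbel(0, 1) noise to log-probabilities and taking the argmax (the
# Gumbel-max trick) draws exact samples from the underlying categorical.
def _example_gumbel_max_trick(device="cpu"):
    log_probs = torch.log(torch.tensor([[0.7, 0.2, 0.1]]))
    noise = sample_gumbel(log_probs.shape, device=device)
    return torch.argmax(log_probs + noise, dim=1)  # 0 with probability ~0.7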
def gumbel_softmax_sample(logits, temperature, device):
    """Draw a sample from the Gumbel-Softmax distribution"""
    y = logits + sample_gumbel(
        logits.shape, tens_type=type(logits.data), device=device
    )
    return F.softmax(y / temperature, dim=1)
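# Illustrative sketch (assumed, not part of the original module): each row of
# the soft sample sums to 1; low temperatures push it towards a one-hot
# vector, high temperatures towards the uniform distribution.
def _example_temperature_effect(device="cpu"):
    logits = torch.randn(2, 4)
    near_onehot = gumbel_softmax_sample(logits, temperature=0.1, device=device)
    near_uniform = gumbel_softmax_sample(logits, temperature=10.0, device=device)
    return near_onehot, near_uniform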
def gumbel_softmax(logits, device, temperature=1.0, hard=False):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.

    Args:
        logits: [batch_size, n_class] unnormalized log-probs
        device: torch device on which the Gumbel noise is sampled
        temperature: positive scalar
        hard: if True, take argmax, but differentiate w.r.t. soft sample y

    Returns:
        [batch_size, n_class] sample from the Gumbel-Softmax distribution.
        If hard=True, then the returned sample will be one-hot, otherwise it
        will be a probability distribution that sums to 1 across classes
    """
    y = gumbel_softmax_sample(logits, temperature, device=device)
    if hard:
        y_hard = onehot_from_logits(y)
        # straight-through estimator: one-hot forward pass, soft gradients
        y = (y_hard - y).detach() + y
    return y
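# Illustrative sketch (assumed, not part of the original module): with
# hard=True the forward pass yields one-hot actions while gradients still
# flow into the logits through the soft sample.
def _example_straight_through(device="cpu"):
    logits = torch.randn(3, 5, requires_grad=True)
    actions = gumbel_softmax(logits, device=device, temperature=1.0, hard=True)
    loss = -actions[:, -1].sum()   # pretend the last action is rewarded
    loss.backward()                # gradients reach the (leaf) logits
    return actions, logits.grad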