Commit 1d438f1
[XLNet] Use pytorch's layernorm like in BERT
See #1089 cc @thomwolf @LysandreJik Also @dhpollack
1 parent 574c5b3 · commit 1d438f1

File tree: 1 file changed (+1, −14 lines)


pytorch_transformers/modeling_xlnet.py

Lines changed: 1 addition & 14 deletions
@@ -337,20 +337,7 @@ def num_hidden_layers(self):
     from apex.normalization.fused_layer_norm import FusedLayerNorm as XLNetLayerNorm
 except (ImportError, AttributeError) as e:
     logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
-    class XLNetLayerNorm(nn.Module):
-        def __init__(self, d_model, eps=1e-12):
-            """Construct a layernorm module in the TF style (epsilon inside the square root).
-            """
-            super(XLNetLayerNorm, self).__init__()
-            self.weight = nn.Parameter(torch.ones(d_model))
-            self.bias = nn.Parameter(torch.zeros(d_model))
-            self.variance_epsilon = eps
-
-        def forward(self, x):
-            u = x.mean(-1, keepdim=True)
-            s = (x - u).pow(2).mean(-1, keepdim=True)
-            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
-            return self.weight * x + self.bias
+    from torch.nn import LayerNorm as XLNetLayerNorm
 
 class XLNetRelativeAttention(nn.Module):
     def __init__(self, config):