File tree Expand file tree Collapse file tree 1 file changed +1
-14
lines changed Expand file tree Collapse file tree 1 file changed +1
-14
lines changed Original file line number Diff line number Diff line change @@ -337,20 +337,7 @@ def num_hidden_layers(self):
337
337
from apex .normalization .fused_layer_norm import FusedLayerNorm as XLNetLayerNorm
338
338
except (ImportError , AttributeError ) as e :
339
339
logger .info ("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex ." )
340
- class XLNetLayerNorm (nn .Module ):
341
- def __init__ (self , d_model , eps = 1e-12 ):
342
- """Construct a layernorm module in the TF style (epsilon inside the square root).
343
- """
344
- super (XLNetLayerNorm , self ).__init__ ()
345
- self .weight = nn .Parameter (torch .ones (d_model ))
346
- self .bias = nn .Parameter (torch .zeros (d_model ))
347
- self .variance_epsilon = eps
348
-
349
- def forward (self , x ):
350
- u = x .mean (- 1 , keepdim = True )
351
- s = (x - u ).pow (2 ).mean (- 1 , keepdim = True )
352
- x = (x - u ) / torch .sqrt (s + self .variance_epsilon )
353
- return self .weight * x + self .bias
340
+ from torch .nn import LayerNorm as XLNetLayerNorm
354
341
355
342
class XLNetRelativeAttention (nn .Module ):
356
343
def __init__ (self , config ):
You can’t perform that action at this time.
0 commit comments