diff --git a/variantworks/networks.py b/variantworks/networks.py
index 0bead87..f0a39ae 100644
--- a/variantworks/networks.py
+++ b/variantworks/networks.py
@@ -143,28 +143,22 @@ def output_ports(self):
             'output_logit': NeuralType(('B', 'W', 'D'), LogitsType()),
         }
 
-    def __init__(self, input_feature_size, num_output_logits,
-                 gru_size=128, gru_layers=2, apply_softmax=False):
+    def __init__(self, sequence_length, input_feature_size, num_output_logits):
         """Construct an Consensus RNN NeMo instance.
 
         Args:
+            sequence_length : Length of sequence to feed into RNN.
             input_feature_size : Length of input feature set.
             num_output_logits : Number of output classes of classifier.
-            gru_size : Number of units in RNN
-            gru_layers : Number of layers in RNN
-            apply_softmax : Apply softmax to the output of the classifier.
 
         Returns:
             Instance of class.
         """
         super().__init__()
         self.num_output_logits = num_output_logits
-        self.apply_softmax = apply_softmax
-        self.gru_size = gru_size
-        self.gru_layers = gru_layers
-        self.gru = nn.GRU(input_feature_size, gru_size, gru_layers, batch_first=True, bidirectional=True)
-        self.classifier = nn.Linear(2 * gru_size, self.num_output_logits)  # 2* for bidirectional
+        self.gru = nn.GRU(input_feature_size, 128, 2, batch_first=True, bidirectional=True)
+        self.classifier = nn.Linear(2 * 128, self.num_output_logits)  # 2 * for bidirectional GRU
 
         self._device = torch.device(
             "cuda" if self.placement == DeviceType.GPU else "cpu")
@@ -181,6 +175,70 @@ def forward(self, encoding):
         """
         encoding, h_n = self.gru(encoding)
         encoding = self.classifier(encoding)
-        if self.apply_softmax:
-            encoding = F.softmax(encoding, dim=2)
         return encoding
+
+
+class ConsensusCNN(TrainableNM):
+    """A Neural Module for training a Consensus CNN model."""
+
+    @property
+    @add_port_docs()
+    def input_ports(self):
+        """Return definitions of module input ports.
+
+        Returns:
+            Module input ports.
+        """
+        return {
+            "encoding": NeuralType(('B', 'W', 'C'), ChannelType()),
+        }
+
+    @property
+    @add_port_docs()
+    def output_ports(self):
+        """Return definitions of module output ports.
+
+        Returns:
+            Module output ports.
+        """
+        return {
+            # Variant type
+            'output_logit': NeuralType(('B', 'W', 'D'), LogitsType()),
+        }
+
+    def __init__(self, sequence_length, input_feature_size, num_output_logits):
+        """Construct a Consensus CNN NeMo instance.
+
+        Args:
+            sequence_length : Length of sequence to feed into network.
+            input_feature_size : Length of input feature set.
+            num_output_logits : Number of output classes of classifier.
+
+        Returns:
+            Instance of class.
+        """
+        super().__init__()
+        self.num_output_logits = num_output_logits
+        self.conv1 = nn.Conv1d(input_feature_size, 128, kernel_size=1, padding=0)
+        self.gru = nn.GRU(128, 16, 1, batch_first=True, bidirectional=True)
+        self.classifier = nn.Linear(32, self.num_output_logits)  # 32 = 2 * 16 for bidirectional GRU
+
+        self._device = torch.device(
+            "cuda" if self.placement == DeviceType.GPU else "cpu")
+        self.to(self._device)
+
+    def forward(self, encoding):
+        """Run a forward pass of the network.
+
+        Args:
+            encoding : Input sequence to run network on.
+
+        Returns:
+            Output of forward pass.
+        """
+        encoding = encoding.permute(0, 2, 1)  # (B, W, C) -> (B, C, W) for Conv1d
+        encoding = self.conv1(encoding)
+        encoding = encoding.permute(0, 2, 1)  # back to batch-first (B, W, 128) for the GRU
+        encoding, h_n = self.gru(encoding)
+        encoding = self.classifier(encoding)
+        return encoding
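
For reviewers, here is a minimal standalone sketch of the tensor shapes flowing through the new `ConsensusCNN` forward pass. It uses plain PyTorch without the NeMo `TrainableNM`/port scaffolding, and all sizes (batch 4, window 100, 10 input features, 5 output classes) are made up for illustration:

```python
import torch
import torch.nn as nn

# Layer stack mirroring ConsensusCNN: pointwise Conv1d -> BiGRU -> Linear.
# Sizes are illustrative only; the real module takes them from its arguments.
conv1 = nn.Conv1d(10, 128, kernel_size=1, padding=0)
gru = nn.GRU(128, 16, 1, batch_first=True, bidirectional=True)
classifier = nn.Linear(32, 5)  # 32 = 2 * 16 hidden units (bidirectional)

encoding = torch.randn(4, 100, 10)    # (B, W, C) as declared in input_ports
x = conv1(encoding.permute(0, 2, 1))  # Conv1d expects (B, C, W) -> (4, 128, 100)
x = x.permute(0, 2, 1)                # back to batch-first (4, 100, 128)
x, _ = gru(x)                         # (4, 100, 32)
logits = classifier(x)                # (4, 100, 5): per-position class logits
print(logits.shape)                   # torch.Size([4, 100, 5])
```

Note the design: the `kernel_size=1` convolution acts as a per-position linear projection of the feature channels, and the bidirectional GRU is what mixes information across positions before the per-position classifier.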