Skip to content

Commit 51d7232

Browse files
committed
fix c_model.py
1 parent 0a266a1 commit 51d7232

File tree

3 files changed

+3
-35
lines changed

3 files changed

+3
-35
lines changed

w2v_classifier/ex_classifier/c_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ def __init__(self, input_dim, n_classes):
1717
nn.Tanh(),
1818
nn.Linear(self.hidden_dim, self.n_classes))
1919

20-
def forward(self, dataset):
21-
sentence_vec, target = dataset
20+
def forward(self, sentence_vec, target=None):
21+
2222
predicted = torch.tensor([[0.9, 0.1]], requires_grad=True)
2323
predicted_value, predicted_class = torch.max(predicted, 1)
2424

w2v_classifier/ex_classifier/new_c_model.py

-30
This file was deleted.

w2v_classifier/test_classifier.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44

55
from ex_classifier.a_sent2vec import *
66
from ex_classifier.b_dataloader import SentDataloader
7-
# from ex_classifier.c_model import SentClassifier
8-
from ex_classifier.new_c_model import SentClassifier
9-
7+
from ex_classifier.c_model import SentClassifier
108
from ex_classifier.d_predict import sent_predictor
119

1210

0 commit comments

Comments
 (0)