Spaces:
Running
Running
nssharmaofficial
commited on
Commit
•
0716538
1
Parent(s):
f7300ff
Fix forward method
Browse files- source/predict_sample.py +1 -1
source/predict_sample.py
CHANGED
@@ -61,7 +61,7 @@ def generate_caption(image: torch.Tensor,
|
|
61 |
features = features.repeat(SEQ_LENGTH, 1, 1)
|
62 |
# features : (SEQ_LENGTH, B=1, IMAGE_EMB_DIM)
|
63 |
|
64 |
-
next_id_pred, (hidden, cell) = image_decoder.forward(lstm_input,
|
65 |
# next_id_pred : (SEQ_LENGTH, 1, VOCAB_SIZE)
|
66 |
|
67 |
next_id_pred = next_id_pred[-1, 0, :]
|
|
|
61 |
features = features.repeat(SEQ_LENGTH, 1, 1)
|
62 |
# features : (SEQ_LENGTH, B=1, IMAGE_EMB_DIM)
|
63 |
|
64 |
+
next_id_pred, (hidden, cell) = image_decoder.forward(lstm_input, hidden, cell)
|
65 |
# next_id_pred : (SEQ_LENGTH, 1, VOCAB_SIZE)
|
66 |
|
67 |
next_id_pred = next_id_pred[-1, 0, :]
|