commit 9abf658fda3a8d55c9a1edfe44f7f0617cd086b4
parent 0f5f3fdc2a50c5dfecf7f71bd9b3cf60a9fb6eee
Author: Étienne Simon <esimon@esimon.eu>
Date: Wed, 16 Apr 2014 18:11:21 +0200
Fix training set splitting for minibatch: cast linspace bounds to int before slicing
Diffstat:
2 files changed, 8 insertions(+), 6 deletions(-)
diff --git a/dataset.py b/dataset.py
@@ -48,11 +48,13 @@ class Dataset(object):
# Yielding batches
ls = numpy.linspace(0, self.train_size, 1+self.train_size/batch_size)
for i in xrange(len(ls)-1):
- left_positive = train_left[ls[i]:ls[i+1]]
- right_positive = train_right[ls[i]:ls[i+1]]
- left_negative = corrupted_left[ls[i]:ls[i+1]]
- right_negative = corrupted_right[ls[i]:ls[i+1]]
- relation = train_relation[ls[i]:ls[i+1]]
+ f = int(ls[i])
+ t = int(ls[i+1])
+ left_positive = train_left[f:t]
+ right_positive = train_right[f:t]
+ left_negative = corrupted_left[f:t]
+ right_negative = corrupted_right[f:t]
+ relation = train_relation[f:t]
yield (relation, left_positive, right_positive, left_negative, right_negative)
def iterate(self, name, batch_size):
diff --git a/model.py b/model.py
@@ -117,7 +117,7 @@ class Model(object):
""" Compute the mean rank and top 10 on a given data. """
batch_size = self.hyperparameters['test_batch_size']
count, mean, top10 = 0, 0, 0
- for (relation, left, right) in self.dataset.iterate(name, batch_size):
+ for (relation, left, right) in self.dataset.iterate(name, batch_size): # TODO Test symmetric
scores = None
for entities in self.dataset.universe_minibatch(batch_size):
if left.shape != entities.shape: