transform

old TransE-like models
git clone https://esimon.eu/repos/transform.git
Log | Files | Refs | README

commit 9abf658fda3a8d55c9a1edfe44f7f0617cd086b4
parent 0f5f3fdc2a50c5dfecf7f71bd9b3cf60a9fb6eee
Author: Étienne Simon <esimon@esimon.eu>
Date:   Wed, 16 Apr 2014 18:11:21 +0200

Fix training set splitting for minibatch

Diffstat:
M dataset.py | 12 +++++++-----
M model.py   |  2 +-
2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/dataset.py b/dataset.py
@@ -48,11 +48,13 @@ class Dataset(object):
         # Yielding batches
         ls = numpy.linspace(0, self.train_size, 1+self.train_size/batch_size)
         for i in xrange(len(ls)-1):
-            left_positive = train_left[ls[i]:ls[i+1]]
-            right_positive = train_right[ls[i]:ls[i+1]]
-            left_negative = corrupted_left[ls[i]:ls[i+1]]
-            right_negative = corrupted_right[ls[i]:ls[i+1]]
-            relation = train_relation[ls[i]:ls[i+1]]
+            f = int(ls[i])
+            t = int(ls[i+1])
+            left_positive = train_left[f:t]
+            right_positive = train_right[f:t]
+            left_negative = corrupted_left[f:t]
+            right_negative = corrupted_right[f:t]
+            relation = train_relation[f:t]
             yield (relation, left_positive, right_positive, left_negative, right_negative)
 
     def iterate(self, name, batch_size):
diff --git a/model.py b/model.py
@@ -117,7 +117,7 @@ class Model(object):
         """ Compute the mean rank and top 10 on a given data. """
         batch_size = self.hyperparameters['test_batch_size']
         count, mean, top10 = 0, 0, 0
-        for (relation, left, right) in self.dataset.iterate(name, batch_size):
+        for (relation, left, right) in self.dataset.iterate(name, batch_size): # TODO Test symmetric
             scores = None
             for entities in self.dataset.universe_minibatch(batch_size):
                 if left.shape != entities.shape: