commit 744de49cdcd00d5bead21197cc31bf226cdb03c0
parent a6ea206decf38474e3c970077f96fabe40811829
Author: Étienne Simon <esimon@esimon.eu>
Date:   Wed, 16 Apr 2014 16:02:51 +0200
Add FB15k dataset builder
Diffstat:
1 file changed, 102 insertions(+), 0 deletions(-)
diff --git a/utils/build Bordes FB15k.py b/utils/build Bordes FB15k.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python2
+
+import sys
+import os
+
+urls = [ 'https://www.hds.utc.fr/everest/lib/exe/fetch.php?id=en%3Atranse&cache=cache&media=en:fb15k.tgz' ]
+
+def get_archive(path):
+    import urllib
+
+    class URLopener(urllib.FancyURLopener):
+          def http_error_default(self, url, fp, errcode, errmsg, headers):
+              print >>sys.stderr, 'Error: {0} {1}'.format(errcode, errmsg)
+              raise IOError
+
+    archive = path+'/archive.tgz'
+    downloaded = False
+    for url in urls:
+        print >>sys.stderr, 'Downloading dataset from "{0}"...'.format(url),
+        try:
+            URLopener().retrieve(url, archive)
+            downloaded = True
+            print >>sys.stderr, ' done'
+        except IOError:
+            pass
+
+    if not downloaded:
+        print >>sys.stderr, 'Error: Unable to download dataset.'
+        sys.exit(1)
+
+def get_raw(path):
+    if os.path.isdir(path+'/raw'):
+        return
+
+    get_archive(path)
+
+    print >>sys.stderr, 'Raw files not found, extracting archive...',
+    raw = path+'/raw'
+    os.mkdir(raw)
+
+    import tarfile
+    tar = tarfile.open(path+'/archive.tgz', 'r:gz')
+    tar.extractall(raw)
+    print >>sys.stderr, ' done'
+
+def compile_dataset(path):
+    get_raw(path)
+    prefix = path+'/raw/FB15k/freebase_mtr100_mte100-'
+    suffix = '.txt'
+
+    print >>sys.stderr, 'Reading train file...',
+    with open(prefix+'train'+suffix, 'r') as file:
+        content = map(lambda line: line.rstrip('\n').split('\t'), file.readlines())
+        [left, relations, right] = map(set, zip(*content))
+    entities = left | right
+    print >>sys.stderr, ' done'
+
+    print >>sys.stderr, 'Writting entities...',
+    e2i, i2e, r2i, i2r = {}, {}, {}, {}
+    with open(path+'/entities', 'w') as file:
+        i=0
+        for entity in entities:
+            e2i[entity]=i
+            i2e[i]=entity
+            file.write(entity+'\n')
+            i+=1
+    print >>sys.stderr, ' done ({0} entities written)'.format(i)
+
+    print >>sys.stderr, 'Writting relations...',
+    with open(path+'/relations', 'w') as file:
+        i=0
+        for relation in relations:
+            r2i[relation]=i
+            i2r[i]=relation
+            file.write(relation+'\n')
+            i+=1
+    print >>sys.stderr, ' done ({0} relations written)'.format(i)
+
+    for name in ['train', 'valid', 'test']:
+        print >>sys.stderr, 'Compiling {0}...'.format(name),
+        count = 0
+        with open(prefix+name+suffix, 'r') as infile:
+            with open(path+'/'+name, 'w') as outfile:
+                for line in infile.readlines():
+                    left, relation, right = line.rstrip('\n').split('\t')
+                    if left in e2i and right in e2i and relation in r2i:
+                        outfile.write('{0}\t{1}\t{2}\n'.format(e2i[left], r2i[relation], e2i[right]))
+                    else:
+                        count+=1
+        print >>sys.stderr, ' done ({0} entit{1} removed)'.format(count, 'y' if count<2 else 'ies')
+
+if __name__ == '__main__':
+    if len(sys.argv)<2:
+        print >>sys.stderr, 'Usage: {0} path'.format(sys.argv[0])
+        sys.exit(1)
+
+    path = sys.argv[1]
+    if not os.path.isdir(path):
+        os.mkdir(path)
+
+    compile_dataset(path)
+    print 'Bordes FB15k was successfully built in {0}'.format(path)