diff options
Diffstat (limited to 'scilo')
-rw-r--r-- | scilo/npz.py | 147 |
1 files changed, 147 insertions, 0 deletions
# Origin: diff --git a/scilo/npz.py b/scilo/npz.py (new file mode 100644, index 0000000..795badc)
'''
scilo - A scientific workflow and efficiency library
Copyright (C) 2012 Joseph Hunkeler <jhunkeler@gmail.com>

This file is part of scilo.

scilo is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

scilo is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with scilo. If not, see <http://www.gnu.org/licenses/>.
'''
import glob
import os
import sys

import numpy

# 'database' is scilo's own SQLite wrapper.  Only mtimedb needs it, so a
# failed import is deferred until an mtimedb is actually constructed; the
# settings and cache classes stay usable without it.
try:
    import database
except ImportError:
    try:
        # Python 3 has no implicit relative imports inside a package.
        from . import database
    except (ImportError, ValueError, SystemError):
        database = None


class settings(object):
    '''
    Generic directory layout rooted at 'path_start'.

    Attributes:
        path: the root directory that was passed in.
        directories: dict mapping 'data', 'result' and 'npz' to
            os.path.join(path, <key>).
    '''

    def __init__(self, path_start):
        self.path = path_start
        self.directories = {}
        for key in ('data', 'result', 'npz'):
            self.directories[key] = os.path.join(self.path, key)


class mtimedb(object):
    '''
    Track the modification time of every data file in an SQLite table.

    Keyword arguments are stored as 'self.settings'; the keys read here are
    'npz' (directory holding the database and the .npz cache) and 'data'
    (directory holding the source data files).
    '''

    def __init__(self, **kwargs):
        if database is None:
            raise ImportError("mtimedb requires the 'database' module")
        self.settings = kwargs
        self._path_database = os.path.join(self.settings['npz'], 'npz_mtime.db')
        self.db = database.s3(self._path_database)
        self.populate()

    def _data_files(self):
        ''' Return every 'name.ext' style path in the data directory '''
        return glob.glob(os.path.join(self.settings['data'], '*.*'))

    def _npz_cache(self):
        ''' Return the object that provides build() and drop() '''
        # A class mixing mtimedb with cache already has both methods; a
        # plain mtimedb borrows them from a cache sharing the same settings.
        if hasattr(self, 'build') and hasattr(self, 'drop'):
            return self
        return cache(**self.settings)

    def populate(self):
        ''' Create and fill the tracking table if the database is empty '''
        path = self._path_database
        # A missing file is treated like an empty one rather than crashing
        # inside os.path.getsize().
        if os.path.exists(path) and os.path.getsize(path) != 0:
            return

        print("Creating file tracking database...")
        self.db.cursor.execute("CREATE TABLE npz(file, mtime)")
        for f in self._data_files():
            mtime = os.path.getmtime(f)
            print("File: %s\tmtime: %f" % (os.path.basename(f), mtime))
            self.insert(f, mtime)
        self.db.connection.commit()

    def insert(self, path, mtime):
        ''' Record 'path' with modification time 'mtime' '''
        self.db.cursor.execute("INSERT INTO npz VALUES (?,?)", (path, mtime))
        self.db.connection.commit()

    def update(self, path, stored, current):
        ''' Replace the 'stored' mtime of 'path' with 'current' '''
        values = (current, path, stored)
        self.db.cursor.execute(
            "UPDATE npz SET mtime=? WHERE file=? AND mtime=?", values)
        self.db.connection.commit()
        print("'%s' updated mtime: %f" % (path, current))

    def delete(self, path):
        ''' Forget 'path' entirely '''
        self.db.cursor.execute("DELETE FROM npz WHERE file==?", (path,))
        self.db.connection.commit()
        print("'%s' removed from mtime database" % path)

    def check(self):
        '''
        Reconcile the database with the data directory.

        A tracked file that no longer exists is removed from the database
        and its cached .npz is dropped.  A tracked file whose mtime differs
        from the stored value gets its record updated and its .npz rebuilt.
        '''
        self.db.cursor.execute("SELECT file, mtime FROM npz")
        stored = [(str(f), mtime) for f, mtime in self.db.cursor.fetchall()]
        current = dict((f, os.path.getmtime(f)) for f in self._data_files())
        npz = self._npz_cache()

        for stored_file, stored_mtime in stored:
            if not os.path.exists(stored_file):
                print("Missing data file: '%s'" % stored_file)
                self.delete(stored_file)
                # Drop the cached array, not the (already gone) data file.
                npz.drop(os.path.join(self.settings['npz'],
                                      os.path.basename(stored_file) + '.npz'))
                continue

            current_mtime = current.get(stored_file)
            if current_mtime is None or current_mtime == stored_mtime:
                continue

            print("'%s' differs" % stored_file)
            self.update(stored_file, stored_mtime, current_mtime)
            print("Rebuilding numpy cache for '%s'" % stored_file)
            npz.build(stored_file)


class cache(object):
    '''
    Maintain .npz copies of the text data files.

    Keyword arguments are stored as 'self.settings'; the keys read here are
    'npz' (cache directory) and 'data' (source directory).  'files' and
    'files_total' describe the .npz files present at construction time.
    '''

    def __init__(self, **kwargs):
        self.settings = kwargs
        self.files = glob.glob(os.path.join(self.settings['npz'], "*.npz"))
        self.files_total = len(self.files)

    def _npz_path(self, path):
        ''' Return the .npz path that build() produces for 'path' '''
        return os.path.join(self.settings['npz'],
                            os.path.basename(path) + '.npz')

    def build(self, path):
        '''
        Generate 'path' npz file in npz directory

        Returns True on success, False when the data file cannot be read or
        parsed or the archive cannot be written.
        '''
        # numpy.savez() appends '.npz' itself and returns None, so failure
        # can only be observed through exceptions.
        target = os.path.join(self.settings['npz'], os.path.basename(path))
        try:
            numpy.savez(target, numpy.loadtxt(path))
        except (IOError, OSError, ValueError):
            return False
        return True

    def drop(self, path):
        ''' Remove 'path' from npz directory '''
        # For security reasons, you are only allowed to unlink files in the
        # configured 'npz' directory; anything else is silently ignored.
        npz_dir = os.path.abspath(self.settings['npz'])
        if os.path.abspath(os.path.dirname(path)) != npz_dir:
            return
        if not os.path.isfile(path):
            return
        print("Unlinking '%s'" % (path))
        os.unlink(path)

    def drop_all(self):
        ''' Remove all npz files '''
        for f in glob.glob(os.path.join(self.settings['npz'], '*.npz')):
            os.unlink(f)

    def populate(self):
        ''' Build an npz file for every data file that lacks one '''
        files = glob.glob(os.path.join(self.settings['data'], '*.*'))
        # Decide what needs building up front so "N of M" stays stable.
        pending = [f for f in files if not os.path.exists(self._npz_path(f))]
        total = len(pending)

        for index, f in enumerate(pending, 1):
            # Write the progress text before the (possibly slow) build.
            sys.stdout.write("NPZ %d of %d: '%s'"
                             % (index, total, os.path.basename(f)))
            sys.stdout.flush()
            if not self.build(f):
                sys.stdout.write(" ... FAIL")
            sys.stdout.write("\n")