Browse Source

implemented seed pass-through

master
Tom Krüger 10 months ago
parent
commit
c7fb1740d9
2 changed files with 77 additions and 57 deletions
  1. +21
    -19
      alma/experiment.py
  2. +56
    -38
      alma/plan.py

+ 21
- 19
alma/experiment.py View File

@ -6,15 +6,19 @@ import threading
import concurrent.futures as concfut import concurrent.futures as concfut
import os import os
import time import time
import random
import sys
from . import batch from . import batch
from . import plan from . import plan
def execute(exp_file): def execute(exp_file):
dispatcher = load(exp_file) dispatcher = load(exp_file)
dispatcher.start() dispatcher.start()
dispatcher.join() dispatcher.join()
def load(exp_file): def load(exp_file):
exp_plan = plan.Plan(exp_file, multiprocessing.Lock()) exp_plan = plan.Plan(exp_file, multiprocessing.Lock())
@ -32,9 +36,10 @@ def load(exp_file):
num_workers = os.cpu_count() num_workers = os.cpu_count()
else: else:
num_workers = int(exp_obj["workers"]) num_workers = int(exp_obj["workers"])
return Dispatcher(exp_mod, exp_plan, num_workers) return Dispatcher(exp_mod, exp_plan, num_workers)
class Dispatcher (threading.Thread): class Dispatcher (threading.Thread):
def __init__(self, exp_mod, exp_plan, num_workers): def __init__(self, exp_mod, exp_plan, num_workers):
threading.Thread.__init__(self) threading.Thread.__init__(self)
@ -46,36 +51,35 @@ class Dispatcher (threading.Thread):
self.__exp_mod = exp_mod self.__exp_mod = exp_mod
for i in range(self.__num_workers): for i in range(self.__num_workers):
self.__workers.append(Worker(exp_mod, exp_plan, i))
self.__workers.append(Worker(exp_mod,
exp_plan,
i))
def run(self): def run(self):
for worker in self.__workers: for worker in self.__workers:
worker.start() worker.start()
def wait_to_continue(workers, stop_called): def wait_to_continue(workers, stop_called):
any_worker_alive = lambda: any(map(lambda w: w.is_alive(), workers))
def any_worker_alive(): any(map(lambda w: w.is_alive(), workers))
while any_worker_alive() and not stop_called.is_set(): while any_worker_alive() and not stop_called.is_set():
time.sleep(0) time.sleep(0)
waiter = threading.Thread(target=wait_to_continue, waiter = threading.Thread(target=wait_to_continue,
args=(self.__workers, args=(self.__workers,
self.__stop_called)) self.__stop_called))
waiter.start() waiter.start()
waiter.join() waiter.join()
if self.__stop_called.is_set(): if self.__stop_called.is_set():
for worker in self.__workers: for worker in self.__workers:
worker.terminate() worker.terminate()
for worker in self.__workers:
for worker in self.__workers:
worker.join() worker.join()
self.__done() self.__done()
def stop(self): def stop(self):
self.__stop_called.set() self.__stop_called.set()
@ -95,31 +99,29 @@ class Dispatcher (threading.Thread):
class Worker (multiprocessing.Process): class Worker (multiprocessing.Process):
def __init__(self, exp_mod, exp_plan, id): def __init__(self, exp_mod, exp_plan, id):
multiprocessing.Process.__init__(self) multiprocessing.Process.__init__(self)
self.__exp_mod = exp_mod self.__exp_mod = exp_mod
self.__exp_plan = exp_plan self.__exp_plan = exp_plan
self.__id = id self.__id = id
def run(self): def run(self):
instance = self.__exp_plan.next() instance = self.__exp_plan.next()
print(instance)
while instance != None:
while instance is not None:
instance_state = self.__exp_plan.load_instance_state(instance) instance_state = self.__exp_plan.load_instance_state(instance)
self.__exp_mod.run(instance, self.__exp_mod.run(instance,
lambda data: self.__exp_plan.save_instance_state( lambda data: self.__exp_plan.save_instance_state(
instance, instance,
data data
),
),
instance_state, instance_state,
worker_id=self.__id) worker_id=self.__id)
self.__exp_plan.done_with(instance) self.__exp_plan.done_with(instance)
instance = self.__exp_plan.next() instance = self.__exp_plan.next()
def terminate(self): def terminate(self):
self.__exp_plan.delete() self.__exp_plan.delete()
multiprocessing.Process.terminate(self) multiprocessing.Process.terminate(self)

+ 56
- 38
alma/plan.py View File

@ -3,9 +3,12 @@ import json
import os import os
import multiprocessing import multiprocessing
import threading import threading
import random
import sys
from . import batch from . import batch
class Plan: class Plan:
def __init__(self, experiment=None, lock=None): def __init__(self, experiment=None, lock=None):
self.experiment = None self.experiment = None
@ -13,15 +16,14 @@ class Plan:
self.pending_instances = [] self.pending_instances = []
self.assigned_instances = [] self.assigned_instances = []
self.instance_states = {} self.instance_states = {}
self.__instance_id_counter = 0 self.__instance_id_counter = 0
self.__lock = threading.Lock() if lock == None else lock
self.__lock = threading.Lock() if lock is None else lock
if experiment: if experiment:
self.create(experiment) self.create(experiment)
def create(self, experiment): def create(self, experiment):
self.experiment = pl.Path(experiment).resolve() self.experiment = pl.Path(experiment).resolve()
self.__set_file() self.__set_file()
@ -40,62 +42,65 @@ class Plan:
with self.__lock: with self.__lock:
self.__update_file() self.__update_file()
def __create_content(self, iterations_left = None):
def __create_content(self, iterations_left=None):
content = {} content = {}
with open(self.experiment, "r") as expf: with open(self.experiment, "r") as expf:
exp_obj = json.loads(expf.read()) exp_obj = json.loads(expf.read())
instances = batch.load(pl.Path(exp_obj["batch"])) instances = batch.load(pl.Path(exp_obj["batch"]))
if iterations_left == None:
if iterations_left is None:
if "iterations" in exp_obj: if "iterations" in exp_obj:
iterations_left = exp_obj["iterations"] - 1 iterations_left = exp_obj["iterations"] - 1
if "seed" in exp_obj:
random.seed(exp_obj["seed"])
else: else:
iterations_left = 0 iterations_left = 0
content["pending"] = instances content["pending"] = instances
content["iterations_left"] = iterations_left content["iterations_left"] = iterations_left
return content return content
def __set_file(self): def __set_file(self):
if self.experiment == None:
if self.experiment is None:
self.file = None self.file = None
else: else:
exp_path = pl.Path(self.experiment) exp_path = pl.Path(self.experiment)
self.file = exp_path.parent / (exp_path.stem + ".plan")
self.file = exp_path.parent / (exp_path.stem + ".plan")
def __load(self): def __load(self):
self.pending_instances = [] self.pending_instances = []
self.assigned_instances = [] self.assigned_instances = []
if not self.file.is_file(): if not self.file.is_file():
return return
with open(self.file, "r") as pfile: with open(self.file, "r") as pfile:
content = json.loads(pfile.read()) content = json.loads(pfile.read())
if "assigned" in content: if "assigned" in content:
self.assigned_instances = content["assigned"] self.assigned_instances = content["assigned"]
self.__instance_id_counter = max(map(lambda i: i["id"], self.assigned_instances)) + 1 self.__instance_id_counter = max(map(lambda i: i["id"], self.assigned_instances)) + 1
if "pending" in content: if "pending" in content:
self.pending_instances = content["pending"] self.pending_instances = content["pending"]
if "iterations_left" in content: if "iterations_left" in content:
self.iterations_left = content["iterations_left"] self.iterations_left = content["iterations_left"]
if "instance_states" in content: if "instance_states" in content:
self.instance_states = content["instance_states"] self.instance_states = content["instance_states"]
if "rand_state" in content:
random.setstate(self.__arr2tup(content["rand_state"]))
def __is_finished(self): def __is_finished(self):
return False if self.file.is_file() else True return False if self.file.is_file() else True
def next(self): def next(self):
@ -107,25 +112,26 @@ class Plan:
self.__load_next_iteration() self.__load_next_iteration()
else: else:
return None return None
next_instance = self.pending_instances.pop() next_instance = self.pending_instances.pop()
next_instance["id"] = self.__instance_id_counter next_instance["id"] = self.__instance_id_counter
next_instance["seed"] = random.randint(0, sys.maxsize)
self.__instance_id_counter += 1 self.__instance_id_counter += 1
self.assigned_instances.append(next_instance) self.assigned_instances.append(next_instance)
self.__update_file() self.__update_file()
return next_instance return next_instance
def done_with(self, instance): def done_with(self, instance):
with self.__lock: with self.__lock:
self.__load() self.__load()
self.assigned_instances = list(filter(lambda i: i["id"] != instance["id"], self.assigned_instances = list(filter(lambda i: i["id"] != instance["id"],
self.assigned_instances ))
self.assigned_instances))
if str(instance["id"]) in self.instance_states: if str(instance["id"]) in self.instance_states:
self.instance_states.pop(str(instance["id"])) self.instance_states.pop(str(instance["id"]))
@ -133,13 +139,13 @@ class Plan:
def __update_file(self): def __update_file(self):
content = {} content = {}
all_done = True all_done = True
content["iterations_left"] = self.iterations_left content["iterations_left"] = self.iterations_left
content["instance_states"] = self.instance_states content["instance_states"] = self.instance_states
if len(self.assigned_instances) > 0: if len(self.assigned_instances) > 0:
content["assigned"] = self.assigned_instances content["assigned"] = self.assigned_instances
all_done = False all_done = False
@ -147,7 +153,9 @@ class Plan:
if len(self.pending_instances) > 0: if len(self.pending_instances) > 0:
content["pending"] = self.pending_instances content["pending"] = self.pending_instances
all_done = False all_done = False
content["rand_state"] = random.getstate()
if all_done: if all_done:
if self.iterations_left > 0: if self.iterations_left > 0:
self.__load_next_iteration() self.__load_next_iteration()
@ -156,9 +164,9 @@ class Plan:
else: else:
self.__write_content(content) self.__write_content(content)
def __load_next_iteration(self):
def __load_next_iteration(self):
content = self.__create_content(self.iterations_left - 1) content = self.__create_content(self.iterations_left - 1)
self.pending_instances = content["pending"] self.pending_instances = content["pending"]
self.iterations_left = content["iterations_left"] self.iterations_left = content["iterations_left"]
@ -168,26 +176,36 @@ class Plan:
with open(self.file, "w") as pfile: with open(self.file, "w") as pfile:
pfile.write(json.dumps(content)) pfile.write(json.dumps(content))
def __serialize_rand_state(self):
return json.dumps(random.getstate())
def __arr2tup(self, arr):
for i, e in enumerate(arr):
if type(e) is list:
arr[i] = self.__arr2tup(e)
return tuple(arr)
def save_instance_state(self, instance, data): def save_instance_state(self, instance, data):
with self.__lock: with self.__lock:
self.__load() self.__load()
self.instance_states[str(instance["id"])] = data self.instance_states[str(instance["id"])] = data
self.__update_file() self.__update_file()
def load_instance_state(self, instance): def load_instance_state(self, instance):
with self.__lock: with self.__lock:
self.__load() self.__load()
if str(instance["id"]) in self.instance_states: if str(instance["id"]) in self.instance_states:
return self.instance_states[str(instance["id"])] return self.instance_states[str(instance["id"])]
else: else:
return "" return ""
def delete(self): def delete(self):
with self.__lock: with self.__lock:
self.__load() self.__load()


Loading…
Cancel
Save