implemented seed pass-through
This commit is contained in:
@@ -6,15 +6,19 @@ import threading
|
||||
import concurrent.futures as concfut
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
import sys
|
||||
|
||||
from . import batch
|
||||
from . import plan
|
||||
|
||||
|
||||
def execute(exp_file):
|
||||
dispatcher = load(exp_file)
|
||||
dispatcher.start()
|
||||
dispatcher.join()
|
||||
|
||||
|
||||
def load(exp_file):
|
||||
exp_plan = plan.Plan(exp_file, multiprocessing.Lock())
|
||||
|
||||
@@ -35,6 +39,7 @@ def load(exp_file):
|
||||
|
||||
return Dispatcher(exp_mod, exp_plan, num_workers)
|
||||
|
||||
|
||||
class Dispatcher (threading.Thread):
|
||||
def __init__(self, exp_mod, exp_plan, num_workers):
|
||||
threading.Thread.__init__(self)
|
||||
@@ -46,19 +51,20 @@ class Dispatcher (threading.Thread):
|
||||
self.__exp_mod = exp_mod
|
||||
|
||||
for i in range(self.__num_workers):
|
||||
self.__workers.append(Worker(exp_mod, exp_plan, i))
|
||||
self.__workers.append(Worker(exp_mod,
|
||||
exp_plan,
|
||||
i))
|
||||
|
||||
def run(self):
|
||||
for worker in self.__workers:
|
||||
worker.start()
|
||||
|
||||
def wait_to_continue(workers, stop_called):
|
||||
any_worker_alive = lambda: any(map(lambda w: w.is_alive(), workers))
|
||||
def any_worker_alive(): any(map(lambda w: w.is_alive(), workers))
|
||||
|
||||
while any_worker_alive() and not stop_called.is_set():
|
||||
time.sleep(0)
|
||||
|
||||
|
||||
waiter = threading.Thread(target=wait_to_continue,
|
||||
args=(self.__workers,
|
||||
self.__stop_called))
|
||||
@@ -70,13 +76,11 @@ class Dispatcher (threading.Thread):
|
||||
for worker in self.__workers:
|
||||
worker.terminate()
|
||||
|
||||
|
||||
for worker in self.__workers:
|
||||
worker.join()
|
||||
|
||||
self.__done()
|
||||
|
||||
|
||||
def stop(self):
|
||||
self.__stop_called.set()
|
||||
|
||||
@@ -102,8 +106,8 @@ class Worker (multiprocessing.Process):
|
||||
|
||||
def run(self):
|
||||
instance = self.__exp_plan.next()
|
||||
print(instance)
|
||||
while instance != None:
|
||||
|
||||
while instance is not None:
|
||||
instance_state = self.__exp_plan.load_instance_state(instance)
|
||||
|
||||
self.__exp_mod.run(instance,
|
||||
@@ -121,5 +125,3 @@ class Worker (multiprocessing.Process):
|
||||
def terminate(self):
|
||||
self.__exp_plan.delete()
|
||||
multiprocessing.Process.terminate(self)
|
||||
|
||||
|
||||
|
32
alma/plan.py
32
alma/plan.py
@@ -3,9 +3,12 @@ import json
|
||||
import os
|
||||
import multiprocessing
|
||||
import threading
|
||||
import random
|
||||
import sys
|
||||
|
||||
from . import batch
|
||||
|
||||
|
||||
class Plan:
|
||||
def __init__(self, experiment=None, lock=None):
|
||||
self.experiment = None
|
||||
@@ -16,12 +19,11 @@ class Plan:
|
||||
|
||||
self.__instance_id_counter = 0
|
||||
|
||||
self.__lock = threading.Lock() if lock == None else lock
|
||||
self.__lock = threading.Lock() if lock is None else lock
|
||||
|
||||
if experiment:
|
||||
self.create(experiment)
|
||||
|
||||
|
||||
def create(self, experiment):
|
||||
self.experiment = pl.Path(experiment).resolve()
|
||||
self.__set_file()
|
||||
@@ -48,21 +50,23 @@ class Plan:
|
||||
|
||||
instances = batch.load(pl.Path(exp_obj["batch"]))
|
||||
|
||||
if iterations_left == None:
|
||||
if iterations_left is None:
|
||||
|
||||
if "iterations" in exp_obj:
|
||||
iterations_left = exp_obj["iterations"] - 1
|
||||
if "seed" in exp_obj:
|
||||
random.seed(exp_obj["seed"])
|
||||
|
||||
else:
|
||||
iterations_left = 0
|
||||
|
||||
|
||||
content["pending"] = instances
|
||||
content["iterations_left"] = iterations_left
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def __set_file(self):
|
||||
if self.experiment == None:
|
||||
if self.experiment is None:
|
||||
self.file = None
|
||||
else:
|
||||
exp_path = pl.Path(self.experiment)
|
||||
@@ -92,11 +96,12 @@ class Plan:
|
||||
if "instance_states" in content:
|
||||
self.instance_states = content["instance_states"]
|
||||
|
||||
if "rand_state" in content:
|
||||
random.setstate(self.__arr2tup(content["rand_state"]))
|
||||
|
||||
def __is_finished(self):
|
||||
return False if self.file.is_file() else True
|
||||
|
||||
|
||||
def next(self):
|
||||
|
||||
with self.__lock:
|
||||
@@ -110,6 +115,7 @@ class Plan:
|
||||
|
||||
next_instance = self.pending_instances.pop()
|
||||
next_instance["id"] = self.__instance_id_counter
|
||||
next_instance["seed"] = random.randint(0, sys.maxsize)
|
||||
self.__instance_id_counter += 1
|
||||
|
||||
self.assigned_instances.append(next_instance)
|
||||
@@ -148,6 +154,8 @@ class Plan:
|
||||
content["pending"] = self.pending_instances
|
||||
all_done = False
|
||||
|
||||
content["rand_state"] = random.getstate()
|
||||
|
||||
if all_done:
|
||||
if self.iterations_left > 0:
|
||||
self.__load_next_iteration()
|
||||
@@ -169,6 +177,16 @@ class Plan:
|
||||
with open(self.file, "w") as pfile:
|
||||
pfile.write(json.dumps(content))
|
||||
|
||||
def __serialize_rand_state(self):
|
||||
return json.dumps(random.getstate())
|
||||
|
||||
def __arr2tup(self, arr):
|
||||
for i, e in enumerate(arr):
|
||||
if type(e) is list:
|
||||
arr[i] = self.__arr2tup(e)
|
||||
|
||||
return tuple(arr)
|
||||
|
||||
def save_instance_state(self, instance, data):
|
||||
|
||||
with self.__lock:
|
||||
|
Reference in New Issue
Block a user