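"""Utility helpers: reading INI-style config files, command line argument
parsing, database connections (MongoDB instance pool, MySQL experiment
database), and persistence of instances, QUBOs, embeddings and solver
results in the pool collections."""
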
import configparser
import os
import argparse
import pymongo
import ssl
import mysql.connector
import networkx as nx
from . import queries
from . import graph
import minorminer
from tqdm import tqdm

def readConfig(configFilePath):
    config = configparser.ConfigParser()
    if os.path.isfile(configFilePath):
        config.read(configFilePath)
    return config

class ArgParser:

    def __init__(self):
        self.__flags = {}
        self.__parser = argparse.ArgumentParser()
        self.__instanceDirArgSet = False
        self.__config = None
        self.__parsedArgs = {}

    def addArg(self, alias,
               shortFlag,
               longFlag,
               help,
               type,
               default=None,
               ignoreDatabaseConfig=False):
        self.__flags[alias] = {"longFlag": longFlag,
                               "hasDefault": False,
                               "ignoreDatabaseConfig": ignoreDatabaseConfig,
                               "type": type}
        if default is not None:
            self.__flags[alias]["hasDefault"] = True
        self.__parser.add_argument("-%s" % shortFlag,
                                   "--%s" % longFlag,
                                   help=help,
                                   type=type,
                                   default=default)

    def addInstanceDirArg(self):
        self.__instanceDirArgSet = True
        self.addArg(alias="datasetDir", shortFlag="d", longFlag="dataset_dir",
                    help="the base directory of the dataset; if this flag is given the others can be omitted",
                    type=str, ignoreDatabaseConfig=True)

    def parse(self):
        self.__parsedArgs = {}
        args = vars(self.__parser.parse_args())
        if self.__instanceDirArgSet:
            self.__config = readConfig(os.path.join(args["dataset_dir"],
                                                    "dataset.config"))
            self.__parseDatasetConfig()
        for alias, flag in self.__flags.items():
            self.__parsedArgs[alias] = self.__processFlag(args, flag)
        self.__config = None
        return self.__parsedArgs

    def __parseDatasetConfig(self):
        for flag, value in self.__config["STRUCTURE"].items():
            self.__parsedArgs[flag] = value

    def __processFlag(self, args, flag):
        longFlag = flag["longFlag"]
        tmpValue = self.__parsedArgs[longFlag] if longFlag in self.__parsedArgs else None
        if flag["ignoreDatabaseConfig"]:
            tmpValue = None
        if args[longFlag]:
            tmpValue = args[longFlag]
        if tmpValue is None:
            tmpValue = flag["type"](input("pass argument %s: " % longFlag))
        return tmpValue
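
# Illustrative usage sketch for ArgParser (not executed here): the parser is
# filled with addInstanceDirArg()/addArg() and then parse() is called once.
# The alias and flag names below are hypothetical examples.
#
#   parser = ArgParser()
#   parser.addInstanceDirArg()
#   parser.addArg(alias="dbConfig", shortFlag="c", longFlag="db_config",
#                 help="path to the database config file", type=str)
#   args = parser.parse()   # e.g. args["dbConfig"], args["datasetDir"], ...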

def getDBContext(dbConfigPath):
    dbContext = {}
    dbContext["client"] = connect_to_instance_pool(dbConfigPath)
    dbContext["db"] = dbContext["client"]["experiments"]
    dbContext["instances"] = dbContext["db"]["instances"]
    dbContext["experimentScopes"] = dbContext["db"]["experiment_scopes"]
    return dbContext

def connect_to_instance_pool(dbConfigPath):
    dbConf = readConfig(dbConfigPath)
    client = pymongo.MongoClient(
        "mongodb://%s:%s@%s:%s/%s"
        % (dbConf["INSTANCE_POOL"]["user"],
           dbConf["INSTANCE_POOL"]["pw"],
           dbConf["INSTANCE_POOL"]["url"],
           dbConf["INSTANCE_POOL"]["port"],
           dbConf["INSTANCE_POOL"]["database"]),
        ssl=True,
        ssl_cert_reqs=ssl.CERT_NONE)
    return client[dbConf["INSTANCE_POOL"]["database"]]

def connect_to_experimetns_db(dbConfigPath):
    dbConfig = readConfig(dbConfigPath)
    return mysql.connector.connect(
        host=dbConfig["EXPERIMENT_DB"]["url"],
        port=dbConfig["EXPERIMENT_DB"]["port"],
        user=dbConfig["EXPERIMENT_DB"]["user"],
        password=dbConfig["EXPERIMENT_DB"]["pw"],
        database=dbConfig["EXPERIMENT_DB"]["database"]
    )
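
# The database helpers above expect an INI config file containing (at least)
# the sections and keys read in the code; the values below are placeholders,
# not real credentials or hosts.
#
#   [INSTANCE_POOL]
#   user = pool_user
#   pw = secret
#   url = localhost
#   port = 27017
#   database = instance_pool
#
#   [EXPERIMENT_DB]
#   user = experiments_user
#   pw = secret
#   url = localhost
#   port = 3306
#   database = experiments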

def frange(start, stop, steps):
    while start < stop:
        yield start
        start += steps
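
# Example for frange(): list(frange(0.0, 1.0, 0.25)) yields
# [0.0, 0.25, 0.5, 0.75]. Because the step is accumulated, the usual
# floating point rounding applies for steps that are not exactly
# representable.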

def create_experiment_scope(db, description, name):
    experimentScope = {}
    experimentScope["instances"] = []
    experimentScope["description"] = description
    experimentScope["_id"] = name.strip()
    db["experiment_scopes"].insert_one(experimentScope)

def write_instance_to_pool_db(db, instance):
    instance_document = instance.writeJSONLike()
    result = db["instances"].insert_one(instance_document)
    return result.inserted_id

def add_instance_to_experiment_scope(db, scope_name, instance_id):
    db["experiment_scopes"].update_one(
        {"_id": scope_name},
        {"$push": {"instances": instance_id}}
    )
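
# Typical flow for the pool database (sketch; names and objects are made up):
# create_experiment_scope() creates an experiment_scopes document,
# write_instance_to_pool_db() stores an instance and returns its inserted id,
# and add_instance_to_experiment_scope() pushes that id onto the scope's
# "instances" list.
#
#   create_experiment_scope(db, "test run", "scope-1")
#   instance_id = write_instance_to_pool_db(db, instance)
#   add_instance_to_experiment_scope(db, "scope-1", instance_id)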

def write_qubo_to_pool_db(collection, qubo, sat_instance_id):
    doc = {}
    doc["instance"] = sat_instance_id
    doc["description"] = {"<qubo>": "<entrys>",
                          "<entrys>": "<entry><entrys> | <entry> | \"\"",
                          "<entry>": "<coupler><energy>",
                          "<energy>": "<real_number>",
                          "<coupler>": "<node><node>",
                          "<node>": "<clause><literal>",
                          "<clause>": "<natural_number>",
                          "<literal>": "<integer>"}
    doc["qubo"] = __qubo_to_JSON(qubo)
    collection.insert_one(doc)

def __qubo_to_JSON(qubo):
    quboJSON = []
    for coupler, value in qubo.items():
        quboJSON.append([coupler, float(value)])
    return quboJSON
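
# __qubo_to_JSON() flattens a QUBO given as a {coupler: energy} mapping into a
# list of [coupler, energy] pairs so it can be stored as an array. Assuming
# couplers are pairs of nodes, a serialized QUBO looks roughly like
# [[[0, 1], -2.0], [[1, 2], 0.5], ...] (values here are made up).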

def write_wmis_embedding_to_pool_db(collection, qubo_id, solver_graph_id, embedding):
    if not __embedding_entry_exists(collection, qubo_id, solver_graph_id):
        __prepare_new_wmis_embedding_entry(collection, qubo_id, solver_graph_id)
    collection.update_one(
        {"qubo": qubo_id, "solver_graph": solver_graph_id},
        {"$push": {"embeddings": __embedding_to_array(embedding)}}
    )

def __embedding_entry_exists(collection, qubo_id, solver_graph_id):
    filter = {"qubo": qubo_id, "solver_graph": solver_graph_id}
    return collection.count_documents(filter) > 0

def __prepare_new_wmis_embedding_entry(collection, qubo_id, solver_graph_id):
    doc = {}
    doc["qubo"] = qubo_id
    doc["solver_graph"] = solver_graph_id
    doc["description"] = {"<embedding>": "<chains>",
                          "<chains>": "<chain><chains> | \"\"",
                          "<chain>": "<original_node><chimera_nodes>",
                          "<chimera_nodes>": "<chimera_node><chimera_nodes> | \"\""}
    doc["embeddings"] = []
    collection.insert_one(doc)

def __embedding_to_array(embedding):
    emb_arr = []
    for node, chain in embedding.items():
        emb_arr.append([node, chain])
    return emb_arr
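
# __embedding_to_array() turns an embedding given as a {node: chain} mapping
# into a list of [node, chain] pairs for storage, e.g. (made-up values)
# [[0, [12, 13]], [1, [20]], ...], matching the "<chain>" grammar in the
# description field above.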

def write_solver_graph_to_pool_db(collection, graph):
    data = nx.node_link_data(graph)
    id = queries.get_id_of_solver_graph(collection, data)
    if id is not None:
        return id
    doc = {}
    doc["data"] = data
    return collection.insert_one(doc).inserted_id
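
# find_wmis_embeddings_for_scope() (below) computes minor embeddings for every
# QUBO in a scope that does not have one yet. minorminer.find_embedding() is
# called with return_overlap=True, so it returns a pair of (embedding,
# success flag); emb[1] == 1 therefore indicates a valid, overlap-free
# embedding, and only then is it written to the pool database.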

def find_wmis_embeddings_for_scope(db, scope, solver_graph):
    solver_graph_id = write_solver_graph_to_pool_db(db["solver_graphs"],
                                                    solver_graph)
    qubos = queries.WMIS_scope_query(db)
    qubos.query(scope)
    for qubo, qubo_id in tqdm(qubos):
        if not __embedding_entry_exists(db["embeddings"], qubo_id, solver_graph_id):
            nx_qubo = graph.qubo_to_nx_graph(qubo)
            emb = minorminer.find_embedding(nx_qubo.edges(),
                                            solver_graph.edges(),
                                            return_overlap=True)
            if emb[1] == 1:
                write_wmis_embedding_to_pool_db(db["embeddings"],
                                                qubo_id,
                                                solver_graph_id,
                                                emb[0])

def save_simulated_annealing_result(collection, result, solver_input, emb_list_index):
    doc = {}
    doc["data"] = result.to_serializable()
    doc["instance"] = solver_input["instance_id"]
    doc["embedding"] = {
        "embedding_id": solver_input["embeddings_id"],
        "list_index": emb_list_index
    }
    collection.insert_one(doc)