You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

141 lines
4.5 KiB

6 years ago
  1. #!/usr/bin/env python3
  2. import argparse
  3. import os
  4. import glob
  5. import json
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  8. import configparser
  9. import scriptUtils
  10. def main():
  11. args = __parseArguments()
  12. __stats(args["comparisonDir"], args["outputDir"])
  13. def __parseArguments():
  14. parser = scriptUtils.ArgParser()
  15. parser.addInstanceDirArg()
  16. parser.addArg(alias="comparisonDir", shortFlag="c", longFlag="comparison_dir",
  17. help="the direcotry with all comparison files", type=str)
  18. parser.addArg(alias="outputDir", shortFlag="s", longFlag="comparison_stats_dir",
  19. help="Directory to store the stats", type=str)
  20. arguments = parser.parse()
  21. arguments["datasetDir"] = os.path.abspath(arguments["datasetDir"])
  22. arguments["comparisonDir"] = os.path.join(arguments["datasetDir"],
  23. arguments["comparisonDir"])
  24. arguments["outputDir"] = os.path.join(arguments["datasetDir"],
  25. arguments["outputDir"])
  26. return arguments
  27. def __stats(comparisonDir, outputDir):
  28. runs = glob.glob(os.path.join(comparisonDir, "run*"))
  29. for run in runs:
  30. stats = __collectStats(run)
  31. print(stats)
  32. runOutputDir = os.path.join(outputDir, os.path.basename(run))
  33. __writeStats(stats, runOutputDir)
  34. def __collectStats(comparisonDir):
  35. files = glob.glob(os.path.join(comparisonDir, "*.cmp"))
  36. stats = {}
  37. stats["match"] = {"count": 0,
  38. "instances": []}
  39. stats["false_positive"] = {"count": 0,
  40. "instances": []}
  41. stats["false_negative"] = {"count": 0,
  42. "instances": []}
  43. stats["unsat"] = {"count": 0,
  44. "instances": []}
  45. for path in files:
  46. comparison = __readComparison(path)
  47. minisat_satisfiable = comparison["minisat_satisfiable"]
  48. qubo_satisfiable = comparison["qubo_satisfiable"]
  49. instanceName = str(os.path.basename(path)).split(".")[0]
  50. if minisat_satisfiable == qubo_satisfiable:
  51. stats["match"]["count"] += 1
  52. stats["match"]["instances"].append(instanceName)
  53. elif minisat_satisfiable == False and qubo_satisfiable == True:
  54. stats["false_positive"]["count"] += 1
  55. stats["false_positive"]["instances"].append(instanceName)
  56. elif minisat_satisfiable == True and qubo_satisfiable == False:
  57. stats["false_negative"]["count"] += 1
  58. stats["false_negative"]["instances"].append(instanceName)
  59. if not minisat_satisfiable:
  60. stats["unsat"]["count"] += 1
  61. stats["unsat"]["instances"].append(instanceName)
  62. return stats
  63. def __readComparison(path):
  64. cmpFile = open(path, "r")
  65. comparison = json.load(cmpFile)
  66. cmpFile.close()
  67. return comparison
  68. def __writeStats(stats, outputDir):
  69. if not os.path.exists(outputDir):
  70. os.makedirs(outputDir)
  71. with open(os.path.join(outputDir,"statusCollection"), "w+") as statusFile:
  72. statusFile.write(json.dumps(stats))
  73. fig = plt.figure()
  74. ax = fig.add_subplot(111)
  75. matchCount = stats["match"]["count"]
  76. falseNegativeCount = stats["false_negative"]["count"]
  77. falsePositiveCount = stats["false_positive"]["count"]
  78. numInstances = matchCount + falseNegativeCount + falsePositiveCount
  79. matchBar = ax.bar(x=0, height=matchCount)
  80. falsePositiveBar = ax.bar(x=1, height=falsePositiveCount)
  81. falseNegativeBar = ax.bar(x=1,
  82. height=falseNegativeCount,
  83. bottom=falsePositiveCount)
  84. ax.axhline(y=matchCount, linestyle="--", color="gray")
  85. ax.axhline(y=falseNegativeCount, linestyle="--", color="gray")
  86. plt.ylabel("SAT Instanzen")
  87. plt.title("Verlgeich Minisat / WMIS qubo mit qbsolv")
  88. plt.xticks([0, 1], ("Gleiches Ergebnis", "Unterschiedliches Ergebnis"))
  89. plt.yticks([0, matchCount, falseNegativeCount, numInstances])
  90. plt.legend((matchBar, falsePositiveBar, falseNegativeBar),
  91. ("Gleiches Ergebnis",
  92. "False Positive",
  93. "False Negative"))
  94. plt.savefig(os.path.join(outputDir, "stats.png"))
  95. if __name__ == "__main__":
  96. main()