diff --git a/generator.py b/generator.py index fa211fb81..16407b4f6 100644 --- a/generator.py +++ b/generator.py @@ -3,13 +3,17 @@ import re, os, shutil, string, sys def generateElements(elementFiles, outputCpp, outputH): elementClasses = dict() + baseClasses = dict() elementHeader = """#ifndef ELEMENTCLASSES_H - #define ELEMENTCLASSES_H - #include - #include "simulation/Element.h" - #include "simulation/elements/Element.h" - """ +#define ELEMENTCLASSES_H + +#include + +#include "simulation/Element.h" +#include "simulation/elements/Element.h" + +""" directives = [] @@ -28,42 +32,92 @@ def generateElements(elementFiles, outputCpp, outputH): classDirectives = [] for d in directives: if d[0] == "ElementClass": - elementClasses[d[1]] = [] - elementHeader += "#define %s %s\n" % (d[2], d[3]) d[3] = string.atoi(d[3]) classDirectives.append(d) - + + elementIDs = sorted(classDirectives, key=lambda directive: directive[3]) + + for d in elementIDs: + tmpClass = d[1] + newClass = "" + baseClass = "Element" + if ':' in tmpClass: + classBits = tmpClass.split(':') + newClass = classBits[0] + baseClass = classBits[1] + else: + newClass = tmpClass + + elementClasses[newClass] = [] + baseClasses[newClass] = baseClass + elementHeader += "#define %s %s\n" % (d[2], d[3]) + for d in directives: if d[0] == "ElementHeader": - elementClasses[d[1]].append(string.join(d[2:], " ")+";") + tmpClass = d[1] + newClass = "" + baseClass = "Element" + if ':' in tmpClass: + classBits = tmpClass.split(':') + newClass = classBits[0] + baseClass = classBits[1] + else: + newClass = tmpClass + elementClasses[newClass].append(string.join(d[2:], " ")+";") - for className, classMembers in elementClasses.items(): - elementHeader += """class {0}: public Element - {{ - public: - {0}(); - virtual ~{0}(); - {1} - }}; - """.format(className, string.join(classMembers, "\n")) + #for className, classMembers in elementClasses.items(): + for d in elementIDs: + tmpClass = d[1] + newClass = "" + baseClass = "Element" + if ':' in tmpClass: + classBits = tmpClass.split(':') + newClass = classBits[0] + baseClass = classBits[1] + else: + newClass = tmpClass - elementHeader += """std::vector GetElements(); - #endif + className = newClass + classMembers = elementClasses[newClass] + elementBase = baseClass + elementHeader += """ +class {0}: public {1} +{{ +public: + {0}(); + virtual ~{0}(); + {2} +}}; + """.format(className, elementBase, string.join(classMembers, "\n\t")) + + elementHeader += """ +std::vector GetElements(); + +#endif """ elementContent = """#include "ElementClasses.h" - std::vector GetElements() - { - std::vector elements; + +std::vector GetElements() +{ + std::vector elements; """; - elementIDs = sorted(classDirectives, key=lambda directive: directive[3]) for d in elementIDs: - elementContent += """ elements.push_back(%s()); - """ % (d[1]) + tmpClass = d[1] + newClass = "" + baseClass = "Element" + if ':' in tmpClass: + classBits = tmpClass.split(':') + newClass = classBits[0] + baseClass = classBits[1] + else: + newClass = tmpClass + elementContent += """elements.push_back(%s()); + """ % (newClass) elementContent += """ return elements; - } +} """; f = open(outputH, "w")