protoc-gen-lua支持嵌套类型

时间:2021-01-30 15:42:36
#!/usr/bin/env python
# -*- encoding:utf8 -*-
# protoc-gen-lua
# Google's Protocol Buffers project, ported to lua.
# https://code.google.com/p/protoc-gen-lua/
#
# Copyright (c) 2010 , 林卓毅 (Zhuoyi Lin) netsnail@gmail.com
# All rights reserved.
#
# Use, modification and distribution are subject to the "New BSD License"
# as listed at <url: http://www.opensource.org/licenses/bsd-license.php >.

import sys
import os.path as path
from cStringIO import StringIO

import plugin_pb2
import google.protobuf.descriptor_pb2 as descriptor_pb2

# Module-level registries.  Of these, only _files is used by the code visible
# in this file: code_gen_file() stores generated Lua text in it and main()
# copies it into the plugin response.
_packages = {}  # not referenced by the code visible in this file
_files = {}     # maps '<name>_pb.lua' -> generated Lua source text
_message = {}   # not referenced by the code visible in this file

# Short alias for the field descriptor message class; its TYPE_* and LABEL_*
# constants are used by the lookup tables and generators below.
# NOTE(review): this relies on plugin_pb2 exposing a `descriptor_pb2`
# attribute -- confirm against the plugin_pb2 module actually shipped.
FDP = plugin_pb2.descriptor_pb2.FieldDescriptorProto

# The plugin exchanges serialized protobuf messages over stdin/stdout, so on
# Windows both streams are switched to binary mode before any I/O happens.
if sys.platform == "win32":
    import msvcrt, os
    msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
    msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)

class CppType:
    """Integer codes for the C++-style type categories (values 1 through 10).

    These are the values stored in the CPP_TYPE lookup table and written to
    the generated ``.cpp_type`` attribute of each field descriptor.
    """

    # Consecutive codes starting at 1; the order of the names fixes the values.
    (
        CPPTYPE_INT32,
        CPPTYPE_INT64,
        CPPTYPE_UINT32,
        CPPTYPE_UINT64,
        CPPTYPE_DOUBLE,
        CPPTYPE_FLOAT,
        CPPTYPE_BOOL,
        CPPTYPE_ENUM,
        CPPTYPE_STRING,
        CPPTYPE_MESSAGE,
    ) = range(1, 11)

# Maps each protobuf wire-level field type (FDP.TYPE_*) to the CppType code
# that code_gen_field() writes as the field's `.cpp_type`.
# Several wire types share one code (e.g. the FIXED/SFIXED/SINT variants map
# onto the plain integer codes, and BYTES maps onto STRING).
# NOTE(review): there is no entry for FDP.TYPE_GROUP, so a group field would
# raise KeyError in code_gen_field() -- confirm groups are intentionally
# unsupported.
CPP_TYPE ={
    FDP.TYPE_DOUBLE         : CppType.CPPTYPE_DOUBLE,
    FDP.TYPE_FLOAT          : CppType.CPPTYPE_FLOAT,
    FDP.TYPE_INT64          : CppType.CPPTYPE_INT64,
    FDP.TYPE_UINT64         : CppType.CPPTYPE_UINT64,
    FDP.TYPE_INT32          : CppType.CPPTYPE_INT32,
    FDP.TYPE_FIXED64        : CppType.CPPTYPE_UINT64,
    FDP.TYPE_FIXED32        : CppType.CPPTYPE_UINT32,
    FDP.TYPE_BOOL           : CppType.CPPTYPE_BOOL,
    FDP.TYPE_STRING         : CppType.CPPTYPE_STRING,
    FDP.TYPE_MESSAGE        : CppType.CPPTYPE_MESSAGE,
    FDP.TYPE_BYTES          : CppType.CPPTYPE_STRING,
    FDP.TYPE_UINT32         : CppType.CPPTYPE_UINT32,
    FDP.TYPE_ENUM           : CppType.CPPTYPE_ENUM,
    FDP.TYPE_SFIXED32       : CppType.CPPTYPE_INT32,
    FDP.TYPE_SFIXED64       : CppType.CPPTYPE_INT64,
    FDP.TYPE_SINT32         : CppType.CPPTYPE_INT32,
    FDP.TYPE_SINT64         : CppType.CPPTYPE_INT64
}

def printerr(*args):
    """Write the arguments to stderr on one line, separated by spaces.

    stdout carries the binary plugin response, so diagnostics must go to
    stderr.  Arguments may be of any type: non-string values are formatted
    with ``%s`` instead of making ``str.join`` raise ``TypeError``.
    """
    # "%s" % (arg,) returns text arguments unchanged and stringifies the rest;
    # the one-element tuple keeps tuple arguments from being unpacked.
    sys.stderr.write(" ".join("%s" % (arg,) for arg in args))
    sys.stderr.write("\n")
    sys.stderr.flush()

class TreeNode(object):
    """A node in the tree of packages, messages and enums seen so far.

    Each node records its name, its parent, its children, and the file name
    and package (a list of name components) that were current when the node
    was created.  Creating a node with a parent links it into that parent's
    child list.
    """

    def __init__(self, name, parent=None, filename=None, package=None):
        super(TreeNode, self).__init__()
        self.child = []
        self.parent = parent
        self.filename = filename
        self.package = package
        if parent is not None:
            self.parent.add_child(self)
        self.name = name

    def add_child(self, child):
        """Append *child* to this node's list of children."""
        self.child.append(child)

    def find_child(self, child_names):
        """Return the descendant reached by following *child_names* in order.

        An empty sequence returns this node itself.  Only the first child
        matching each name is followed.

        Raises:
            LookupError: if no child matches the first remaining name.  (In
                Python 2 ``LookupError`` derives from ``StandardError``, the
                type raised previously, so existing handlers keep working.)
        """
        if not child_names:
            return self
        wanted = child_names[0]
        for node in self.child:
            if node.name == wanted:
                return node.find_child(child_names[1:])
        raise LookupError(
            "no child named %r under %r" % (wanted, self.get_path()))

    def get_child(self, child_name):
        """Return the direct child named *child_name*, or None if absent."""
        for node in self.child:
            if node.name == child_name:
                return node
        return None

    def get_path(self, end = None):
        """Return the dot-joined names from below *end* down to this node.

        With ``end=None`` the walk goes all the way up; a root node whose
        name is the empty string therefore produces a leading dot.
        """
        pos = self
        out = []
        while pos is not None and pos is not end:
            out.append(pos.name)
            pos = pos.parent
        out.reverse()
        return '.'.join(out)

    def get_global_name(self):
        """Return the full dotted path of this node (see get_path)."""
        return self.get_path()

    def get_local_name(self):
        """Return this node's path relative to its package node.

        Walks up the ancestors and stops at the first one whose name equals
        the last component of ``self.package``, or at the topmost ancestor
        when none matches.
        """
        pos = self
        while pos.parent:
            pos = pos.parent
            if self.package and pos.name == self.package[-1]:
                break
        return self.get_path(pos)

    def __str__(self):
        return self.to_string(0)

    def __repr__(self):
        return str(self)

    def to_string(self, indent = 0):
        """Return a multi-line dump of this subtree, indented by *indent*."""
        return ' '*indent + '<TreeNode ' + self.name + '(\n' + \
                ','.join([i.to_string(indent + 4) for i in self.child]) + \
                ' '*indent +')>\n'

class Env(object):
    """Generation state shared by the code_gen_* functions.

    Holds the tree of every package/message/enum seen so far, the current
    scope inside that tree, and -- per file -- the lists of generated Lua
    fragments (``descriptor``, ``context``, ``message``, ``register``).
    """

    filename = None
    package = None
    extend = None
    descriptor = None
    message = None
    context = None
    register = None

    def __init__(self):
        self.message_tree = TreeNode('')
        self.scope = self.message_tree

    def get_global_name(self):
        """Return the full dotted name of the current scope."""
        return self.scope.get_global_name()

    def get_local_name(self):
        """Return the current scope's name relative to its package."""
        return self.scope.get_local_name()

    def get_ref_name(self, type_name):
        """Return the name used in generated code to refer to *type_name*.

        A type found in the tree is returned by its package-relative name,
        prefixed with ``<file>_pb.`` when it was declared in another file.
        A type that is not found is assumed to belong to the current file
        and is returned with the leading dot and current package stripped.
        """
        try:
            node = self.lookup_name(type_name)
        except Exception:
            # The tree only contains what has been generated so far, so a
            # failed lookup is treated as a type declared in this file.
            # Catch Exception rather than everything, so KeyboardInterrupt
            # and SystemExit are not swallowed.
            package = '.'.join(self.package or [])
            # With an empty package only the leading '.' has to go; the
            # former fixed "+ 2" slice also ate the first character of the
            # type name in that case.
            prefix = '.' + package + '.' if package else '.'
            if type_name.startswith(prefix):
                return type_name[len(prefix):]
            return type_name.lstrip('.')
        if node.filename != self.filename:
            return node.filename + '_pb.' + node.get_local_name()
        return node.get_local_name()

    def lookup_name(self, name):
        """Resolve a dotted *name* to its TreeNode.

        A leading dot means the name is looked up from the tree root;
        otherwise it is looked up from the parent of the current scope.
        Propagates whatever ``find_child`` raises when the name is unknown.
        """
        names = name.split('.')
        if names[0] == '':
            return self.message_tree.find_child(names[1:])
        else:
            return self.scope.parent.find_child(names)

    def enter_package(self, package):
        """Return the node for dotted *package*, creating missing nodes.

        An empty package maps to the tree root.
        """
        if not package:
            return self.message_tree
        names = package.split('.')
        pos = self.message_tree
        for i, name in enumerate(names):
            new_pos = pos.get_child(name)
            if new_pos is not None:
                pos = new_pos
            else:
                # Nothing below this point exists yet: build the rest.
                return self._build_nodes(pos, names[i:])
        return pos

    def enter_file(self, filename, package):
        """Start a new file: reset the fragment lists and enter its package."""
        self.filename = filename
        self.package = package.split('.')
        self._init_field()
        self.scope = self.enter_package(package)

    def exit_file(self):
        """Finish the current file and clear the per-file state."""
        self._init_field()
        self.filename = None
        self.package = []
        self.scope = self.scope.parent

    def enter(self, message_name):
        """Create a child node named *message_name* and make it the scope."""
        self.scope = TreeNode(message_name, self.scope, self.filename,
                              self.package)

    def exit(self):
        """Return to the parent of the current scope."""
        self.scope = self.scope.parent

    def _init_field(self):
        """Reset the four per-file fragment lists to empty lists."""
        self.descriptor = []
        self.context = []
        self.message = []
        self.register = []

    def _build_nodes(self, node, names):
        """Create a chain of nodes for *names* under *node*; return the last."""
        parent = node
        for i in names:
            parent = TreeNode(i, parent, self.filename, self.package)
        return parent

class Writer(object):
    """In-memory text sink that is called like a function.

    Every call emits the current indentation, then the optional prefix given
    at construction, then the data.  Using the instance in a ``with``
    statement indents by four spaces for the duration of the block.
    """

    def __init__(self, prefix=None):
        self.__prefix = prefix
        self.__indent = ''
        self.io = StringIO()

    def getvalue(self):
        """Return everything written so far."""
        return self.io.getvalue()

    def __enter__(self):
        # One nesting level is four spaces.
        self.__indent = self.__indent + ' ' * 4
        return self

    def __exit__(self, type, value, trackback):
        # Drop one nesting level; exceptions are not suppressed.
        self.__indent = self.__indent[:-4]

    def __call__(self, data):
        pieces = [self.__indent, data]
        if self.__prefix:
            # The prefix goes between the indentation and the data.
            pieces.insert(1, self.__prefix)
        for piece in pieces:
            self.io.write(piece)

# Lua source text used by code_gen_field() as `.default_value` for a field
# that has no explicit default, is not repeated and carries no type_name.
# Fields with a type_name get "nil" before this table is consulted, so the
# TYPE_MESSAGE and TYPE_ENUM entries are only reached when a descriptor of
# that type lacks a type_name.
DEFAULT_VALUE = {
    FDP.TYPE_DOUBLE         : '0.0',
    FDP.TYPE_FLOAT          : '0.0',
    FDP.TYPE_INT64          : '0',
    FDP.TYPE_UINT64         : '0',
    FDP.TYPE_INT32          : '0',
    FDP.TYPE_FIXED64        : '0',
    FDP.TYPE_FIXED32        : '0',
    FDP.TYPE_BOOL           : 'false',
    FDP.TYPE_STRING         : '""',
    FDP.TYPE_MESSAGE        : 'nil',
    FDP.TYPE_BYTES          : '""',
    FDP.TYPE_UINT32         : '0',
    FDP.TYPE_ENUM           : '1',
    FDP.TYPE_SFIXED32       : '0',
    FDP.TYPE_SFIXED64       : '0',
    FDP.TYPE_SINT32         : '0',
    FDP.TYPE_SINT64         : '0',
}

def code_gen_enum_item(index, enum_value, env):
    """Generate the Lua EnumValueDescriptor for one enum value.

    Appends the ``local ... = protobuf.EnumValueDescriptor()`` declaration to
    ``env.descriptor`` and the attribute assignments to ``env.context``.

    Args:
        index: zero-based position of the value inside its enum.
        enum_value: descriptor providing ``name`` and ``number``.
        env: generation environment; the enum must be the current scope.

    Returns:
        The name of the generated Lua variable.
    """
    qualified = '%s.%s' % (env.get_local_name(), enum_value.name)
    obj_name = qualified.replace('.', '_').upper() + '_ENUM'
    env.descriptor.append(
        "local %s = protobuf.EnumValueDescriptor();\n" % obj_name)

    # Every line is emitted as "<obj_name><attribute assignment>".
    out = Writer(obj_name)
    for line in ('.name = "%s"\n' % enum_value.name,
                 '.index = %d\n' % index,
                 '.number = %d\n' % enum_value.number):
        out(line)

    env.context.append(out.getvalue())
    return obj_name

def code_gen_enum(enum_desc, env):
    """Generate the Lua EnumDescriptor for *enum_desc* and all its values.

    The enum becomes the current scope while its values are generated and
    the scope is restored before returning.  The declaration goes to
    ``env.descriptor`` and the attribute assignments to ``env.context``
    (after those of the individual values).

    Returns:
        The name of the generated Lua variable.
    """
    env.enter(enum_desc.name)
    obj_name = env.get_local_name().upper().replace('.', '_')
    env.descriptor.append(
        "local %s = protobuf.EnumDescriptor();\n" % obj_name)

    out = Writer(obj_name)
    out('.name = "%s"\n' % enum_desc.name)
    out('.full_name = "%s"\n' % env.get_global_name())

    # Generate the values in declaration order, collecting their Lua names.
    value_names = [code_gen_enum_item(pos, item, env)
                   for pos, item in enumerate(enum_desc.value)]
    out('.values = {%s}\n' % ','.join(value_names))

    env.context.append(out.getvalue())
    env.exit()
    return obj_name

def _lua_string_literal(text):
    """Return *text* as a double-quoted Lua string literal."""
    # Backslashes must be escaped first so the escapes added afterwards are
    # not escaped a second time.
    escaped = (text.replace('\\', '\\\\')
                   .replace('"', '\\"')
                   .replace('\n', '\\n')
                   .replace('\r', '\\r'))
    return '"%s"' % escaped


def code_gen_field(index, field_desc, env):
    """Generate the Lua FieldDescriptor for one field or extension.

    Appends the ``local ... = protobuf.FieldDescriptor()`` declaration to
    ``env.descriptor``, the attribute assignments to ``env.context`` and, for
    an extension, a ``RegisterExtension`` call to ``env.register``.

    Args:
        index: zero-based position of the field in its containing list.
        field_desc: field descriptor (name, number, label, type, and the
            optional default_value / type_name / extendee).
        env: generation environment; the containing message is the scope.

    Returns:
        The name of the generated Lua variable.

    Raises:
        KeyError: if ``field_desc.type`` has no entry in CPP_TYPE (or in
            DEFAULT_VALUE when that table is consulted).
    """
    full_name = env.get_local_name() + '.' + field_desc.name
    obj_name = full_name.upper().replace('.', '_') + '_FIELD'
    env.descriptor.append(
        "local %s = protobuf.FieldDescriptor();\n"% obj_name
    )

    context = Writer(obj_name)

    context('.name = "%s"\n' % field_desc.name)
    context('.full_name = "%s"\n' % (
        env.get_global_name() + '.' + field_desc.name))
    context('.number = %d\n' % field_desc.number)
    context('.index = %d\n' % index)
    context('.label = %d\n' % field_desc.label)

    if field_desc.HasField("default_value"):
        context('.has_default_value = true\n')
        value = field_desc.default_value
        if field_desc.type == FDP.TYPE_STRING:
            # Escape the text so a default containing quotes, backslashes or
            # line breaks still yields a valid Lua string literal.
            context('.default_value = %s\n' % _lua_string_literal(value))
        else:
            # NOTE(review): every other default is emitted verbatim as Lua
            # source; confirm that is valid for bytes, enum-name and
            # inf/nan defaults.
            context('.default_value = %s\n'%value)
    else:
        context('.has_default_value = false\n')
        if field_desc.label == FDP.LABEL_REPEATED:
            default_value = "{}"
        elif field_desc.HasField('type_name'):
            default_value = "nil"
        else:
            default_value = DEFAULT_VALUE[field_desc.type]
        context('.default_value = %s\n' % default_value)

    if field_desc.HasField('type_name'):
        # NOTE(review): for a nested type the reference keeps its dots (e.g.
        # OUTER.INNER) while descriptor variables are named with underscores
        # (OUTER_INNER) -- confirm this resolves in the generated Lua.
        type_name = env.get_ref_name(field_desc.type_name).upper()
        if field_desc.type == FDP.TYPE_MESSAGE:
            context('.message_type = %s\n' % type_name)
        else:
            context('.enum_type = %s\n' % type_name)

    if field_desc.HasField('extendee'):
        # Extensions are registered on the message class they extend.
        type_name = env.get_ref_name(field_desc.extendee)
        env.register.append(
            "%s.RegisterExtension(%s)\n" % (type_name, obj_name)
        )

    context('.type = %d\n' % field_desc.type)
    context('.cpp_type = %d\n\n' % CPP_TYPE[field_desc.type])
    env.context.append(context.getvalue())
    return obj_name

def code_gen_message(message_descriptor, env, containing_type = None):
    """Generate the Lua Descriptor and Message class for one message.

    Nested messages, enums, fields and extensions are generated first, in
    that order, with this message as the current scope; the scope is
    restored before returning.  The descriptor declaration goes to
    ``env.descriptor``, the attribute assignments to ``env.context`` and the
    ``protobuf.Message(...)`` line to ``env.message``.

    Args:
        message_descriptor: descriptor of the message to generate.
        env: generation environment.
        containing_type: Lua variable name of the enclosing message's
            descriptor, or None for a top-level message.

    Returns:
        The name of the generated Lua descriptor variable.
    """
    env.enter(message_descriptor.name)
    full_name = env.get_local_name()
    obj_name = full_name.upper().replace('.', '_')
    env.descriptor.append("%s = protobuf.Descriptor();\n" % obj_name)

    out = Writer(obj_name)
    out('.name = "%s"\n' % message_descriptor.name)
    out('.full_name = "%s"\n' % env.get_global_name())

    # Nested messages receive this descriptor as their containing type.
    nested_names = [code_gen_message(sub_desc, env, obj_name)
                    for sub_desc in message_descriptor.nested_type]
    out('.nested_types = {%s}\n' % ', '.join(nested_names))

    enum_names = [code_gen_enum(sub_enum, env)
                  for sub_enum in message_descriptor.enum_type]
    out('.enum_types = {%s}\n' % ', '.join(enum_names))

    field_names = [code_gen_field(pos, fdesc, env)
                   for pos, fdesc in enumerate(message_descriptor.field)]
    out('.fields = {%s}\n' % ', '.join(field_names))

    # A message is extendable when it declares at least one extension range.
    extendable = 'true' if len(message_descriptor.extension_range) > 0 \
        else 'false'
    out('.is_extendable = %s\n' % extendable)

    ext_names = [code_gen_field(pos, fdesc, env)
                 for pos, fdesc in enumerate(message_descriptor.extension)]
    out('.extensions = {%s}\n' % ', '.join(ext_names))

    if containing_type:
        out('.containing_type = %s\n' % containing_type)

    env.message.append(
        '%s = protobuf.Message(%s)\n' % (full_name, obj_name))

    env.context.append(out.getvalue())
    env.exit()
    return obj_name

def write_header(writer):
    """Emit the one-line "generated file" Lua comment through *writer*."""
    banner = "-- Generated By protoc-gen-lua Do not Edit\n"
    writer(banner)

def code_gen_file(proto_file, env, is_gen):
    """Process one .proto file and optionally record its generated Lua.

    Every file is walked so that its types are added to the environment's
    tree; Lua source is assembled and stored in ``_files`` (under
    ``<name>_pb.lua``) only when *is_gen* is true.

    Args:
        proto_file: file descriptor (name, package, dependency, enum_type,
            message_type).
        env: generation environment shared across files.
        is_gen: whether output was requested for this file.
    """
    filename = path.splitext(proto_file.name)[0]
    env.enter_file(filename, proto_file.package)

    # Each dependency, minus its extension, becomes a Lua require below.
    includes = [path.splitext(dep)[0] for dep in proto_file.dependency]

    # NOTE: file-level extensions (proto_file.extension) are not generated.

    for enum_desc in proto_file.enum_type:
        code_gen_enum(enum_desc, env)
        # Top-level enum values are also exported as plain name = number
        # assignments.
        for enum_value in enum_desc.value:
            env.message.append('%s = %d\n' % (enum_value.name,
                                              enum_value.number))

    for msg_desc in proto_file.message_type:
        code_gen_message(msg_desc, env)

    if is_gen:
        lua = Writer()
        write_header(lua)
        lua('local protobuf = require "protobuf"\n')
        for inc in includes:
            lua('local %s_PB = require("%s_pb")\n' % (inc.upper(), inc))
        lua("module('%s_pb')\n" % env.filename)

        # Explicit loops instead of map(): map() is lazy on Python 3 and
        # would silently emit nothing there.
        lua('\n\n')
        for fragment in env.descriptor:
            lua(fragment)
        lua('\n')
        for fragment in env.context:
            lua(fragment)
        lua('\n')
        env.message.sort()
        for fragment in env.message:
            lua(fragment)
        lua('\n')
        for fragment in env.register:
            lua(fragment)

        _files[env.filename+ '_pb.lua'] = lua.getvalue()
    env.exit_file()

def main():
    """Run the plugin: read a request from stdin, write a response to stdout.

    Parses a CodeGeneratorRequest, processes every file it lists (generating
    output only for those named in ``file_to_generate``) and serializes a
    CodeGeneratorResponse holding one entry per generated Lua file.
    """
    request = plugin_pb2.CodeGeneratorRequest()
    request.ParseFromString(sys.stdin.read())

    env = Env()
    for proto_file in request.proto_file:
        # Every file is processed; only the requested ones produce output.
        wanted = proto_file.name in request.file_to_generate
        code_gen_file(proto_file, env, wanted)

    response = plugin_pb2.CodeGeneratorResponse()
    for out_name, lua_source in _files.items():
        entry = response.file.add()
        entry.name = out_name
        entry.content = lua_source

    sys.stdout.write(response.SerializeToString())


if __name__ == "__main__":
    main()

  将 protoc-gen-lua 文件的内容替换为以上代码即可。

  lua里使用的时候,复合字段是有值的,直接取即可

  local person = person_pb.Person()

  person.company.name = "xxx" --其中company为复合字段,也就是另一个类型,比如 company_pb.Company()

  直接用就可以了