关于平台调用(P/Invoke)的回调特性在面向对象的应用

时间:2023-01-24 19:00:32

关于平台调用的相关知识在园子里面各位大牛已经总结得相当到位,大多博客描述的内容描述了如何使用托管程序调用非托管程序集(主要是由C写的dll),少量则涉及到了非托管程序集调用托管程序的回调内容。

事实上,合理使用平台调用的托管-非托管的回调特性,可以帮助我们使用非托管程序集时,更能遵循关注分离的原则,编写出具有高度抽象,低耦合的代码,本博客的目的便是根据前人的经验,摸索出出自己的一种更为面向对象的方式,使得托管与非托管的交互更加地面向对象,与大家分享,交流以下,以便相互借鉴,学习提高。

以应对现实世界中多态,抽象的应用场景——需要注意的是,本博客涉及到了一些简单的C++内容(比如类,纯虚基类的概念)。

在本人的工作中,遇到了这样的一个应用场景,我们由于一些内部技术人员组成的原因,需要使用C语言编写读取本地文件(在我所在的取证行业里,这里文件专业的名称为"镜像",由此往下,文件均称为镜像)的核心代码,之前的方案由另一位技术提出的:

我们使用了Windows API中的HANDLE,在上层调用的时候,将对应镜像的HANDLE传入, 由于.Net Framework的BCL(Basic Class Library)中的FileStream向上层开放了Win32的HANDLE特性(通过属性FileStream.Handle/SafeFileHandle(为防止释放所带来的逻辑问题,SafeFileHandle为Handle的Win32安全封装)获取),

我们可以很直接地实现如上的设计思路。

简单代码样本如下:

c部分:

//在代码的另外一个地方,定义了此方法(函数)的实现.
extern
"C" _declspec(dllexport) void Read(HANDLE handle);

c#部分:

 internal const string ReaderAssembly= "Reader.dll";
 [DllImport(ReaderAssembly, CharSet = CharSet.Auto, CallingConvention = CallingConvention.Cdecl)]
 internal static extern void Read(SafeFileHandle handle);

调用部分:

var fs = File.OpenRead(文件路径);
Read(fs.SafeFileHandle);

示例代码很简单,总结一下即为,C#部分将FileStream的Handle直接传入到非托管中,C语言部分将直接利用Windows API所提供的功能对HANDLE进行读取(写入的前提是打开文件的方法为可写,比如使用File.Open方法),文件指针的跳转等灵活的操作,无论是托管还是非托管部分,代码都十分简单,十分自♂由。

但是随后的一个需求(也不算随后的需求,事实上,该需求在开发之初就应该考虑)却暴露除了本方法所带来的一些限制——软件需要支持经过编码后的镜像,如果仍然像之前的方法一样,直接将编码后的镜像的HANDLE传入非托管,那么非托管部分需要判断镜像的编码格式,并自行解码,读取,但是这违背了关注分离的原则,Read部分应只关注镜像(解码后)数据本身的内容,所带来的坏处就是,Read的代码处理了一些本不应被自己关注的镜像编码的信息,如若添加任何一种编码,就需要更改该Read函数的内部实现,当然,又经验的C程序员可能说,将所有镜像编码相关的代码封装到一个代码组成,Read方法调用该组成不就好了吗?那么该组成对应的大致代码如下(当然,为求尽量简单,以下的代码形式从软件工程上而言是很粗糙的,存在相当多的地方可以优化,但是作为描述镜像读取的形式,应该是足够的):

 /// <summary>
        /// 
        /// </summary>
        /// <param name="handle">HANDLE</param>
        /// <param name="buffer">缓冲区</param>
        /// <param name="sizeToRead">需要读取的大小</param>
        /// <returns>实际读取的大小</returns>
        int ReadHandleWrapper(HANDLE handle, void* buffer, int sizeToRead) {
            //判断镜像类型;
            var res = DetectHandle(handle);
            //根据镜像类型进行读取;
            switch (res) {
                case 0:
                    //...
                    break;
                case 1:
                    //...
                    break;
                default:
                    //...
                    break;
            }
        }
        /// <summary>
        /// 检测镜像类型;
        /// </summary>
        /// <param name="handle"></param>
        /// <returns></returns>
        int DetectHandle(HANDLE handle) {
            ...
        }

确实,这样的确也能避免Read代码的膨胀,并且在Read之外的其它地方也能充分复用处理镜像编码的部分,如果没有其它特别的需求,这样的设计方式已经足够了。但是,针对我所在的项目场景,经过编码的镜像除了镜像本身的原数据外,往往还具有一些描述字段信息,托管世界需要拿到这些信息呈现出来(这并不是新需求,也是在本项目开发之初应该考虑进去的),那么现在,我们是否就应该考虑如何在上面一段代码中加入向托管代码提供元数据的描述信息的功能?由于镜像的编码类型具有多样性,那么我们在每加入一种类型的编码支持,就要直接或间接地修改该代码,长此以往,该部分的代码量将会变得非常大,代码中充满了各种判断,维护也是比较辛苦的。

如果你是一个对面向对象有足够了解的技术人员,你会发现,上面的代码都是面向过程的,什么是面向过程?以我个人的一点从面向对象的角度而言的浅薄理解,所谓面向过程,即实现就是定义,定义就是实现,面向过程趋于将不同的事物完全分开(即使从逻辑而言,他们具有相同的特点),作不同的定义,有多少种实现,便有多少种定义。在面向对象的做法中,我们将这些相同的特点提取出来,制定一个统一定义,在面向对象的语言中称为抽象,多态(当然,存在一些面向对象的语言因为种种原因是没有这些特点的,比如Python,这里不做评述,避免引战嫌疑),具体到代码的表现便是诸如interface,虚类等概念,我们可以为一个定义创造多种实现。

那么,针对我上述所提到以上问题,适用面向对象的做法,我们将镜像的判断,解码从Read部分的非托管代码中彻底分离,尝试将解码后的原数据内容抽象为一个定义,使Read方法不再直接使用HANDLE参数,而是使用镜像原数据的定义,至于具体传入的实现是怎样的,Read将不会关注,只需传入的原数据是什么。

当然,我这么说可能还是比较含糊,那么直接从代码入手或许会更直接,我们尝试将解码后的原数据内容抽象为一个定义,那么这个定义需同时能够在非托管,托管中使用,有没有现成的一个定义呢?根据我的经验,我认为没有什么比System.IO.Stream这样一个抽象类更适合作为我们的定义了——FileStream也是继承自Stream,那么我们可以尝试使用Stream处理上述问题——我们将所有的镜像编码(包括未编码)格式的解码过程抽象为一个Stream,根据具体镜像的格式,编写不同的实现,镜像的读写,指针跳转都将依赖于Stream实例,通过平台调用的回调特性,为非托管做一层Stream的适配,使其适配至非托管中我们所定义的一个Stream契约,非托管中Read也将依赖于该契约所在的项目(或者头文件)。这样我们便将镜像的定义和实现分开,利用抽象,多态的特性实现了关注分离的原则。

在我的设计中,包含了四个项目。

1.Contracts契约项目(非托管),用作定义非托管下的Stream,仅有一个头文件为Contracts.h。

2.StreamAdapter项目(非托管),用作实现一个托管Stream->非托管Stream的适配。

3.StreamAdapter项目(托管),通过平台调用项目2,使用回调的特性使用项目2中的方法。

4.Read(非托管)项目,引用了契约项目。

代码如下:

1.契约项目中Contracts.h中的Stream为System.IO.Stream的非托管版,我们在Contracts.h中编写了如下的一个虚基类:

#pragma once


class Stream {
public:
    //获取长度
    virtual __int64 GetLength() = 0;

    //获取位置;
    virtual __int64 GetPosition() = 0;

    //设定位置;(Seek);
    virtual void SetPosition(__int64 position) = 0;

    virtual bool CanRead() = 0;

    //读取;
    //参数lpBuffer:缓冲区;
    //参数nNumberOfBytesToRead:读取大小;
    //参数nPos:流的位置;
    //返回:实际读取大小;
    virtual bool Read(BYTE * lpBuffer, unsigned long nNumberOfBytesToRead, unsigned long *nRetSize, __int64 nPos) = 0;
    
    //是否可写;
    virtual bool CanWrite() = 0;

    //写入数据;
    //参数lpBuffer:缓冲区;
    //参数nNumberOfBytesToRead:写入大小;
    //参数nPos:流的位置;
    //返回:实际写入大小;
    virtual bool Write(BYTE* lpBuffer, unsigned long nNumberOfBytesToWrite, unsigned long *nRetSize, __int64 nPos) = 0;

    //关闭流;
    virtual void Close() = 0;
};

2..StreamAdapter项目(非托管),实现了一个Stream,所有的方法实现均是调用回调。

#include "../Contracts/Contracts.h"
#ifdef STREAMADAPTER_EXPORTS
#define XYZAPI __declspec(dllexport)
#else
#define XYZAPI __declspec(dllimport)
#endif


//数据源规范(RAW,VHD等);
class XYZAPI UnManagedStream : public Stream {
public:
    UnManagedStream(int s)
    {
            
    }
    //获取长度
    long long GetLength();

    void SetGetLengthFunc(long long(*getLengthFunc)());

    //获取位置;
    long long GetPosition();

    //设定位置;(Seek);
    void SetPosition(long long position);
    
    void SetPositionFunc(long long(*getPositionFunc)(), void(*setPositionFunc)(long long pos));

    bool CanRead();

    void SetCanReadFunc(bool(*canReadFunc)());

    //读取;
    //参数lpBuffer:缓冲区;
    //参数nNumberOfBytesToRead:读取大小;
    //参数nPos:流的位置;
    //返回:实际读取大小;
    bool Read(BYTE * lpBuffer, unsigned long nNumberOfBytesToRead, unsigned long *nRetSize, long long nPos);

    void SetReadFunc(int(*readFunc)(BYTE *lpBuffer, int nNumberOfBytesToRead));

    //是否可写;
    bool CanWrite();

    void SetCanWriteFunc(bool(*canWriteFunc)());

    //写入数据;
    //参数lpBuffer:缓冲区;
    //参数nNumberOfBytesToRead:写入大小;
    //参数nPos:流的位置;
    //返回:实际写入大小;
    bool Write(BYTE* lpBuffer, unsigned long nNumberOfBytesToWrite, unsigned long *nRetSize, long long nPos);

    void SetWriteFunc(int(*writeFunc)(BYTE* lpBuffer, int nNumberOfBytesToWrite));

    //关闭流;
    void Close();

private:
    //typedef char* (*__cdecl AddCallBack)(const char* a, const char* b);
    long long(*_getLengthFunc)();

    long long(*_getPositionFunc)();

    void(*_setPositionFunc)(long long pos);

    bool(*_canReadFunc)();

    int(*_readFunc)(BYTE *lpBuffer, int nNumberOfBytesToRead);

    int(*_writeFunc)(BYTE* lpBuffer, int nNumberOfBytesToWrite);

    //是否可写;
    bool(*_canWriteFunc)();
};

 
 

并向托管中开放了所有回调的驻留读写

 

//获取非托管流;
extern "C" _declspec(dllexport) UnManagedStream* _cdecl CreateUnManagedStream() {
    return new UnManagedStream(2);
}

extern "C" _declspec(dllexport) void SetGetLengthFunc(UnManagedStream* stream, long long(*getLengthFunc)()) {
    if (stream == nullptr) {
        return;
    }
    stream->SetGetLengthFunc(getLengthFunc);
}

extern "C" _declspec(dllexport) void SetPositionFunc(UnManagedStream* stream, long long(*getPositionFunc)(), void(*setPositionFunc)(long long pos)) {
    if (stream == nullptr) {
        return;
    }
    stream->SetPositionFunc(getPositionFunc, setPositionFunc);
}

extern "C" _declspec(dllexport) void SetCanReadFunc(UnManagedStream* stream, bool(*canReadFunc)()) {
    if (stream == nullptr) {
        return;
    }

    stream->SetCanReadFunc(canReadFunc);
}

extern "C" _declspec(dllexport) void SetCanWriteFunc(UnManagedStream* stream, bool(*canWriteFunc)()) {
    if (stream == nullptr) {
        return;
    }

    stream->SetCanWriteFunc(canWriteFunc);
}

extern "C" _declspec(dllexport) void SetWriteFunc(UnManagedStream* stream, int(*writeFunc)(BYTE* lpBuffer, int nNumberOfBytesToWrite)) {
    if (stream == nullptr) {
        return;
    }

    stream->SetWriteFunc(writeFunc);
}

extern "C" _declspec(dllexport) void SetReadFunc(UnManagedStream* stream, int(*readFunc)(BYTE *lpBuffer, int nNumberOfBytesToRead)) {
    if (stream == nullptr) {
        return;
    }

    stream->SetReadFunc(readFunc);
}

extern "C" _declspec(dllexport) void CloseStream(UnManagedStream* stream) {
    if (stream == nullptr) {
        return;
    }

    stream->Close();
    delete stream;
}


extern "C" _declspec(dllexport) long long GetStreamLength(Stream *stream) {
    if (stream == nullptr) {
        return -1;
    }

    return stream->GetLength();
}

extern "C" _declspec(dllexport) long long GetStreamPosition(Stream *stream) {
    if (stream == nullptr) {
        return -1;
    }

    stream->GetPosition();
}

extern "C" _declspec(dllexport) void SetStreamPosition(Stream *stream, long long pos) {
    if (stream == nullptr) {
        return ;
    }

    stream->SetPosition(pos);
}

3.StreamAdpater(托管)项目中编写了这样一个适配器类,(由于非托管和托管内存管理上的差异,我们需小心避免垃圾回收器对非托管的回调影响,详见注释)

 /// <summary>
    /// 非托管流适配器,可映射任意流至非托管环境下的一个UnmanagedStream对象;
    /// <!--本类实现了IDisposable,实例对象保存在Static队列中,当且仅当在调用了Dispose()后实例才可能被回收-->
    /// </summary>
    public partial class UnmanagedStreamAdapter : IDisposable {
        public UnmanagedStreamAdapter(Stream stream) {
            OriStream = stream ?? throw new ArgumentNullException(nameof(stream));
            StreamPtr = CreateUnManagedStream();

            _instances.Add(this);

            InitializePtr();
        }

        //在向非托管环境传递委托实例时,如无其它托管引用该实例,
        //由于非托管环境在托管堆标记委托实例的引用,
        //则委托实例可能在被下一次垃圾回收时被回收,
        //所以此处需要单独引用对应的该委托实例,已确保委托实例不会被不正确地回收。
        List<Delegate> delegates = new List<Delegate>();
        
        private long OnGetPos() => OriStream.Position;
        private void OnSetPos(long pos) {
#if DEBUG
            if(pos >= 13878934528 && OriStream.Length == 13878934528) {
                
            }
#endif
            if(pos > OriStream.Length) {
                LoggerService.WriteCallerLine($"{nameof(pos)} out of range.({nameof(pos)}:{pos},available length:{OriStream.Length}");
                OriStream.Position = pos;
                return;
            }
            else if(pos < 0){
                LoggerService.WriteCallerLine($"{nameof(pos)} can't be less than zero.({nameof(pos)}:{pos},available length:{OriStream.Length}");
                return;
            }
            
            OriStream.Position = pos;
        }
        private long OnGetLength() => OriStream.Length;
        private bool OnCanWrite() => OriStream.CanWrite;
        private bool OnCanRead() => OriStream.CanRead;

        private int OnWrite(IntPtr buffer, int count) {
            if (writeBuffer.Length < count) {
                writeBuffer = new byte[count];
            }

            var oldPos = OriStream.Position;
            OriStream.Write(writeBuffer, 0, count);
            Marshal.Copy(buffer, writeBuffer, 0, count);
            return (int)(OriStream.Position - oldPos);
        }

        private int OnRead(IntPtr buffer, int count) {
            if (this.readBuffer.Length < count) {
                this.readBuffer = new byte[count];
            }

            var readCount = OriStream.Read(readBuffer, 0, count);

#if DEBUG
            if (count == 30380032) {

            }
            
            if(count != readCount && OriStream.Length == 13878934528) {

            }
#endif

            Marshal.Copy(readBuffer, 0, buffer, readCount);
            return readCount;
        }



        /// <summary>
        /// 初始化非托管委托指针(函数指针);
        /// </summary>
        private void InitializePtr() {
            //位置委托;
            SetInt64Delegate setPos = OnSetPos;
            GetInt64Delegate getPos = OnGetPos;
            delegates.Add(setPos);
            delegates.Add(getPos);

            SetPositionFunc(
                StreamPtr,
                getPos,
                setPos
            );

            //长度委托;
            GetInt64Delegate getLength = OnGetLength;
            delegates.Add(getLength);
            SetGetLengthFunc(StreamPtr, getLength);

            //可读/写委托;
            GetBoolDelegate canRead = OnCanRead;
            GetBoolDelegate canWrite = OnCanWrite;
            delegates.Add(canRead);
            delegates.Add(canWrite);
            SetCanReadFunc(StreamPtr, canRead);
            SetCanWriteFunc(StreamPtr, canWrite);
            
            //写入委托;
            WriteDelegate write = OnWrite;
            delegates.Add(write);
            SetWriteFunc(StreamPtr, write);
            
            //读取委托;
            ReadDelegate read = OnRead;
            delegates.Add(read);
            SetReadFunc(StreamPtr, read);
        }

        private byte[] readBuffer = new byte[4096];
        private byte[] writeBuffer = new byte[4096];

        public Stream OriStream { get; }

        private IntPtr _streamPtr;
        public IntPtr StreamPtr {
            get {
                if (_disposed) {
                    throw new ObjectDisposedException(nameof(UnmanagedStreamAdapter));
                }
                return _streamPtr;
            }
            private set => _streamPtr = value;
        }

        private bool _disposed;

        /// <summary>
        /// 当且仅当在调用了Dispose()后实例才可能被回收
        /// </summary>
        public void Dispose() {
            CloseStream(StreamPtr);
            StreamPtr = IntPtr.Zero;
            _disposed = true;
            if (_instances.Contains(this)) {
                _instances.Remove(this);
            }
        }

        //所有实例的引用必须存放在此静态示例列表;避免本单位被GC回收后;
        //非托管环境意外进行了调用,引发了非法访问内存的错误;
        //只有在调用Dispose方法后,才可解除引用,使得垃圾回收按照预期正常执行对实例的回收;
        private static List<UnmanagedStreamAdapter> _instances = new List<UnmanagedStreamAdapter>();

        ~UnmanagedStreamAdapter() {
            if (!_disposed) {
                Dispose();
            }
        }
    }

    public partial class UnmanagedStreamAdapter {
        private const string streamAsm = "StreamAdapter.dll";

        //[return: MarshalAs(UnmanagedType.I8)]
        [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
        private delegate long GetInt64Delegate();

        //[return: MarshalAs(UnmanagedType.I8)]
        [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
        private delegate void SetInt64Delegate(long int64);

        //[return: MarshalAs(UnmanagedType.Bool)]
        [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
        private delegate bool GetBoolDelegate();

        [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
        private delegate int WriteDelegate(IntPtr data, int count);

        [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
        private delegate int ReadDelegate(IntPtr buffer, int count);

        [DllImport(streamAsm, CharSet = CharSet.Auto, CallingConvention = CallingConvention.Cdecl)]
        private static extern IntPtr CreateUnManagedStream();

        [DllImport(streamAsm, CharSet = CharSet.Auto, CallingConvention = CallingConvention.Cdecl)]
        private static extern void SetGetLengthFunc(IntPtr stream, [MarshalAs(UnmanagedType.FunctionPtr)] GetInt64Delegate getLengthFunc);

        [DllImport(streamAsm, CharSet = CharSet.Auto, CallingConvention = CallingConvention.Cdecl)]
        private static extern void SetPositionFunc(IntPtr stream, [MarshalAs(UnmanagedType.FunctionPtr)] GetInt64Delegate getPositionFunc,
            [MarshalAs(UnmanagedType.FunctionPtr)]SetInt64Delegate setPositionFunc);

        [DllImport(streamAsm, CharSet = CharSet.Auto, CallingConvention = CallingConvention.Cdecl)]
        private static extern void SetCanReadFunc(IntPtr stream, [MarshalAs(UnmanagedType.FunctionPtr)]GetBoolDelegate canReadFunc);

        [DllImport(streamAsm, CharSet = CharSet.Auto, CallingConvention = CallingConvention.Cdecl)]
        private static extern void SetCanWriteFunc(IntPtr stream, GetBoolDelegate canWriteFunc);

        [DllImport(streamAsm, CharSet = CharSet.Auto, CallingConvention = CallingConvention.Cdecl)]
        private static extern void SetWriteFunc(IntPtr stream, [MarshalAs(UnmanagedType.FunctionPtr)]WriteDelegate writeFunc);

        [DllImport(streamAsm, CharSet = CharSet.Auto, CallingConvention = CallingConvention.Cdecl)]
        private static extern void SetReadFunc(IntPtr stream, [MarshalAs(UnmanagedType.FunctionPtr)]ReadDelegate readFunc);

        [DllImport(streamAsm, CharSet = CharSet.Auto, CallingConvention = CallingConvention.Cdecl)]
        private static extern void CloseStream(IntPtr stream);

    }

4.Read(非托管)项目中Read方法只需更改为如下方式即可.

#include "Contracts/Contracts.h"
//在代码的另外一个地方,定义了此方法(函数)的实现.
extern "C" _declspec(dllexport) void Read(Stream* stream);

 

致此,虽然这种方法在设计实现时工作量较大,但是个人很好地实现了托管-非托管中关注分离的原则,提升了可维护性。

当然,园子里如果有大牛有更简洁,更优秀的方法能够达到同样的目的,请不要吝啬你的字数,欢迎在评论区指正,大家一同讨论学习。

我的笔力实在是拙计见肘,不能很准确地表达我的真实想法,如果有同道中人能够理解我的意思,并能够帮助我更好地表达,我将不胜感激。