// 自定义异步可插入协议 (代码记录) -- Custom asynchronous pluggable protocol (code notes)

// http_protocol.cc

#include "trident/glue/protocol_impl/http_protocol.h"
#include "base/logging.h"
#include <WinInet.h>
#include <ExDisp.h>

namespace trident {

// Constructs a protocol object with a reference count of zero.
//
// pOuterUnknown: controlling unknown when the object is aggregated, or NULL
//                when it is used standalone. The pointer is stored as-is; no
//                reference is taken on it here.
//
// Members are initialised in declaration order, and has_lock_request_ now
// starts out as false instead of being left indeterminate.
HttpProtocol::HttpProtocol(IUnknown* pOuterUnknown)
    : reference_count_(0),
      outer_unknown_(pOuterUnknown),
      inner_unknown_(NULL),
      grf_BindF_(0),
      has_lock_request_(false) {
  // The delegating IUnknown methods call through this pointer when there is no
  // outer unknown; it addresses the INonDelegatingUnknown sub-object, so the
  // three calls land on the NonDelegating* implementations.
  inner_unknown_ = reinterpret_cast<IUnknown*>(static_cast<INonDelegatingUnknown*>(this));
  ZeroMemory(&bind_info_, sizeof(BINDINFO));
  bind_info_.cbSize = sizeof(BINDINFO);
}

// Releases whatever Start() obtained through IInternetBindInfo::GetBindInfo.
// bind_info_ is zero-filled (apart from cbSize) by the constructor, so this is
// also safe when Start() was never called.
HttpProtocol::~HttpProtocol() {
  ::ReleaseBindInfo(&bind_info_);
}

// INonDelegatingUnknown
// Non-delegating QueryInterface: answers for the interfaces this object
// implements itself.
//
// riid:      requested interface id.
// ppvObject: receives the interface pointer, or NULL when unsupported.
// Returns E_INVALIDARG when ppvObject is NULL, E_NOINTERFACE for an
// unsupported riid, S_OK otherwise (with one reference added).
//
// The reference is taken only after a match is found. Taking it up front and
// releasing it on a miss would drop a freshly constructed object (count 0)
// back to zero and destroy it while the caller still holds the raw pointer.
STDMETHODIMP HttpProtocol::NonDelegatingQueryInterface(REFIID riid, void** ppvObject) {
  if (ppvObject == NULL) {
    return E_INVALIDARG;
  }

  *ppvObject = NULL;
  if (riid == IID_IUnknown) {
    *ppvObject = static_cast<INonDelegatingUnknown*>(this);
  } else if (riid == IID_IInternetProtocolRoot) {
    *ppvObject = static_cast<IInternetProtocolRoot*>(this);
  } else if (riid == IID_IInternetProtocol) {
    *ppvObject = static_cast<IInternetProtocol*>(this);
  } else if (riid == IID_IInternetProtocolEx) {
    *ppvObject = static_cast<IInternetProtocolEx*>(this);
  } else if (riid == IID_IInternetProtocolInfo) {
    *ppvObject = static_cast<IInternetProtocolInfo*>(this);
  } else {
    return E_NOINTERFACE;
  }

  NonDelegatingAddRef();
  return S_OK;
}

// Atomically bumps this object's own reference count and returns the new value.
STDMETHODIMP_(ULONG) HttpProtocol::NonDelegatingAddRef() {
  const LONG updated = ::InterlockedIncrement(&reference_count_);
  return (ULONG)updated;
}

// Atomically drops one reference from this object's own count and destroys the
// object when the count reaches zero. Returns the count after the decrement.
//
// The decision and the return value both use the result of the interlocked
// decrement; the member is not touched again, because it no longer exists once
// `delete this` has run.
STDMETHODIMP_(ULONG) HttpProtocol::NonDelegatingRelease() {
  const LONG remaining = ::InterlockedDecrement(&reference_count_);
  if (remaining == 0) {
    delete this;
  }
  return static_cast<ULONG>(remaining);
}

// IUnknown
// Delegating QueryInterface: forwards to the outer unknown when one was
// supplied at construction, otherwise to inner_unknown_.
STDMETHODIMP HttpProtocol::QueryInterface(REFIID riid, void** ppvObject) {
  IUnknown* const target = outer_unknown_ ? outer_unknown_ : inner_unknown_;
  return target->QueryInterface(riid, ppvObject);
}

// Delegating AddRef: forwards to the outer unknown when one was supplied at
// construction, otherwise to inner_unknown_.
STDMETHODIMP_(ULONG) HttpProtocol::AddRef() {
  IUnknown* const target = outer_unknown_ ? outer_unknown_ : inner_unknown_;
  return target->AddRef();
}

// Delegating Release: forwards to the outer unknown when one was supplied at
// construction, otherwise to inner_unknown_.
STDMETHODIMP_(ULONG) HttpProtocol::Release() {
  IUnknown* const target = outer_unknown_ ? outer_unknown_ : inner_unknown_;
  return target->Release();
}

// IInternetProtocolRoot; XP SP2 and earlier come in through this interface.
//
// Records the URL, sink and bind-info objects for this request, looks up an
// IServiceProvider (sink first, bind info as fallback) and fetches the BINDINFO
// and bind flags.
//
// url, protocol_sink, bind_info: must all be non-NULL, else E_INVALIDARG.
// flags, reserved:               not used by this implementation.
// Returns the failing HRESULT when GetBindInfo fails, S_OK otherwise.
STDMETHODIMP HttpProtocol::Start(LPCWSTR url, IInternetProtocolSink* protocol_sink, IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) {

  if (bind_info == NULL || protocol_sink == NULL || url == NULL)
    return E_INVALIDARG;

  bind_url_ = GURL(url);

  spSink_ = protocol_sink;

  spBindinfo_ = bind_info;

  // Drop any provider kept from an earlier Start() so the out-parameter below
  // is written through an empty smart pointer.
  spServiceProvider_.Release();
  spSink_->QueryInterface(IID_IServiceProvider, (void**)&spServiceProvider_);
  if (!spServiceProvider_)
    spBindinfo_->QueryInterface(IID_IServiceProvider, (void**)&spServiceProvider_);
  DCHECK(spServiceProvider_);

  // BINDINFO
  //http://msdn.microsoft.com/en-us/library/ie/aa767897(v=vs.85).aspx
  //http://msdn.microsoft.com/en-us/library/ie/aa741006(v=vs.85).aspx#Handling_BINDINFO_St
  HRESULT result = spBindinfo_->GetBindInfo(&grf_BindF_, &bind_info_);
  DCHECK(result == S_OK);
  // Do not carry on with a BINDINFO that was never filled in.
  if (FAILED(result))
    return result;

  // Fall back to the system ANSI code page when the binding did not name one.
  if (!bind_info_.dwCodePage)
    bind_info_.dwCodePage = ::GetACP();

  // Disabled sketch of progress/data reporting, kept for reference:
  /*bind_info_->ReportProgress(BINDSTATUS_FINDINGRESOURCE, strData);
  bind_info_->ReportProgress(BINDSTATUS_CONNECTING, strData);
  bind_info_->ReportProgress(BINDSTATUS_SENDINGREQUEST, strData);
  bind_info_->ReportProgress(BINDSTATUS_VERIFIEDMIMETYPEAVAILABLE, CAtlString(m_url.GetMimeType()));
  bind_info_->ReportData(BSCF_FIRSTDATANOTIFICATION, 0, bind_url_.GetDataLength());
  bind_info_->ReportData(BSCF_LASTDATANOTIFICATION | BSCF_DATAFULLYAVAILABLE, m_url.GetDataLength(), m_url.GetDataLength());*/

  return S_OK;
}

// IInternetProtocolRoot::Continue stub: ignores pProtocolData and reports
// success.
STDMETHODIMP HttpProtocol::Continue(PROTOCOLDATA* pProtocolData) {
  const HRESULT status = S_OK;
  return status;
}

// IInternetProtocolRoot::Abort stub: ignores both arguments and reports
// success.
// Asserts were seen under IE6/IE8: Abort was found to still be called after
// Terminate had been called.
STDMETHODIMP HttpProtocol::Abort(HRESULT reason, DWORD options) {
  const HRESULT status = S_OK;
  return status;
}

// IInternetProtocolRoot::Terminate stub: ignores options and reports success.
// No state held by this object is released here.
STDMETHODIMP HttpProtocol::Terminate(DWORD options) {
  const HRESULT status = S_OK;
  return status;
}

// IInternetProtocolRoot::Suspend: not implemented by this handler.
STDMETHODIMP HttpProtocol::Suspend() {
  const HRESULT status = E_NOTIMPL;
  return status;
}

// IInternetProtocolRoot::Resume: not implemented by this handler.
STDMETHODIMP HttpProtocol::Resume() {
  const HRESULT status = E_NOTIMPL;
  return status;
}

// IInternetProtocol::Read stub: copies nothing into pv.
//
// pv, size: destination buffer and its capacity; not used.
// pcbRead:  when non-NULL, set to 0 so the caller never sees an indeterminate
//           byte count alongside a success code.
// Returns S_OK.
// NOTE(review): a real implementation probably needs S_FALSE / E_PENDING here
// once data delivery is wired up -- confirm against the urlmon contract.
STDMETHODIMP HttpProtocol::Read(void* pv, ULONG size, ULONG* pcbRead) {
  if (pcbRead != NULL)
    *pcbRead = 0;
  return S_OK;
}

// IInternetProtocol::Seek stub: ignores all arguments, leaves new_position
// untouched and reports success.
STDMETHODIMP HttpProtocol::Seek(LARGE_INTEGER move, DWORD origin, ULARGE_INTEGER* new_position) {
  const HRESULT status = S_OK;
  return status;
}

// IInternetProtocol::LockRequest: marks the request as locked by setting
// has_lock_request_. options is not used. Always returns S_OK.
STDMETHODIMP HttpProtocol::LockRequest(DWORD options) {
  this->has_lock_request_ = true;
  const HRESULT status = S_OK;
  return status;
}

// IInternetProtocol::UnlockRequest: clears has_lock_request_. Always returns
// S_OK.
STDMETHODIMP HttpProtocol::UnlockRequest() {
  this->has_lock_request_ = false;
  const HRESULT status = S_OK;
  return status;
}

// IInternetProtocolEx; XP SP3 and later come in through this interface.
//
// Extracts the absolute URI string from `uri` and forwards to Start().
//
// uri: must be non-NULL, else E_INVALIDARG. It is borrowed from the caller:
//      this method took no reference on it and therefore must not Release it.
// protocol_sink, bind_info, flags, reserved: passed through to Start().
// Returns the failing HRESULT when GetAbsoluteUri fails, otherwise whatever
// Start() returns.
STDMETHODIMP HttpProtocol::StartEx(IUri* uri, IInternetProtocolSink* protocol_sink, IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) {
  if (uri == NULL) {
    return E_INVALIDARG;
  }

  BSTR uri_URL = NULL;
  const HRESULT uri_result = uri->GetAbsoluteUri(&uri_URL);

  // Copy the text out and free the BSTR before any early return.
  std::wstring url;
  if (uri_URL != NULL) {
    url = uri_URL;
    ::SysFreeString(uri_URL);
  }
  if (FAILED(uri_result)) {
    return uri_result;
  }

  return Start(url.c_str(), protocol_sink, bind_info, flags, reserved);
}

// IInternetProtocolInfo::ParseUrl: this handler does no URL parsing of its own.
//
// Nothing is written to pwzResult / pcchResult, so success must not be
// reported; INET_E_DEFAULT_ACTION asks the caller to apply its default parsing
// instead of consuming an untouched buffer.
STDMETHODIMP  HttpProtocol::ParseUrl(LPCWSTR pwzUrl, PARSEACTION ParseAction, DWORD dwParseFlags, LPWSTR pwzResult,
  DWORD cchResult,  DWORD *pcchResult, DWORD dwReserved) {
  return INET_E_DEFAULT_ACTION;
}

// IInternetProtocolInfo::CombineUrl: this handler does not combine URLs itself.
//
// Nothing is written to pwzResult / pcchResult, so success must not be
// reported; INET_E_DEFAULT_ACTION asks the caller to apply its default
// combination logic instead of consuming an untouched buffer.
STDMETHODIMP  HttpProtocol::CombineUrl( LPCWSTR pwzBaseUrl,  LPCWSTR pwzRelativeUrl,  DWORD dwCombineFlags, LPWSTR pwzResult,
  DWORD cchResult,DWORD *pcchResult,DWORD dwReserved)  {
  return INET_E_DEFAULT_ACTION;
}

// IInternetProtocolInfo::CompareUrl stub: ignores all arguments and returns
// S_OK.
// NOTE(review): S_OK from CompareUrl conventionally means "the URLs are the
// same" -- confirm that answering S_OK for every pair is intended.
STDMETHODIMP  HttpProtocol::CompareUrl( LPCWSTR pwzUrl1,LPCWSTR pwzUrl2,DWORD dwCompareFlags) {
  const HRESULT status = S_OK;
  return status;
}

// IInternetProtocolInfo::QueryInfo stub: ignores all arguments, writes nothing
// to pBuffer / pcbBuf and returns S_OK.
// NOTE(review): reporting success without filling the buffer looks
// unintended -- confirm which HRESULT callers expect for unsupported options.
STDMETHODIMP  HttpProtocol::QueryInfo(LPCWSTR pwzUrl,  QUERYOPTION OueryOption, DWORD dwQueryFlags, LPVOID pBuffer, DWORD cbBuffer,
  DWORD *pcbBuf, DWORD dwReserved)  {
  const HRESULT status = S_OK;
  return status;
}

// Returns the verb named by bind_info_.dwBindVerb as text: "GET", "POST",
// "PUT", or the custom verb string carried in bind_info_.szCustomVerb.
//
// Returns an empty string (after a DCHECK) when the verb is not one of the
// handled values or the custom verb pointer is NULL; a NULL pointer is never
// handed to the std::wstring constructor.
std::wstring HttpProtocol::GetVerbStr() const {
  // Pointer-to-const: three of the four sources are string literals.
  const wchar_t* verb = NULL;
  switch (bind_info_.dwBindVerb) {
    case BINDVERB_GET:
      verb = L"GET";
      break;
    case BINDVERB_POST:
      verb = L"POST";
      break;
    case BINDVERB_PUT:
      verb = L"PUT";
      break;
    case BINDVERB_CUSTOM:
      verb = bind_info_.szCustomVerb;
      break;
    default:
      break;
  }
  DCHECK(verb);
  return verb != NULL ? std::wstring(verb) : std::wstring();
}


bool HttpProtocol::GetDataToSend(char** lplpData, DWORD* pdwSize) const {

  if(bind_info_.dwBindVerb == BINDVERB_GET)
    return false;

  if (bind_info_.stgmedData.tymed == TYMED_HGLOBAL) {
    if(lplpData)
      *lplpData = (char*)bind_info_.stgmedData.hGlobal;
    if(pdwSize)
      *pdwSize = bind_info_.cbstgmedData;
    return true;
  } else {
    return false;
  }
}

}







// http_protocol.h

#ifndef TRIDENT_PROTOCOL_HTTP_PROTOCOL_H_
#define TRIDENT_PROTOCOL_HTTP_PROTOCOL_H_

// Implementation modelled on the Win2K sources:
//   private/inet/urlmon/iapp/cnet.cxx
//   private/inet/urlmon/iapp/cnethttp.cxx

#include <atlbase.h>
#include <urlmon.h>
#include <vector>
#include "base/basictypes.h"
#include "url/gurl.h"

namespace trident {

// COM组件聚合帮助接口
// 参考:http://msdn.microsoft.com/en-us/library/windows/desktop/dd390339(v=vs.85).aspx
struct INonDelegatingUnknown {
  STDMETHOD(NonDelegatingQueryInterface)(REFIID riid, void** ppvObject) = 0;
  STDMETHOD_(ULONG, NonDelegatingAddRef)() = 0;
  STDMETHOD_(ULONG, NonDelegatingRelease)() = 0;
};

// Pluggable-protocol handler object that can be used standalone or aggregated.
//
// It implements IInternetProtocolEx (and therefore IInternetProtocol and
// IInternetProtocolRoot) plus IInternetProtocolInfo. The IUnknown methods
// delegate to the outer unknown when one was given at construction; the
// INonDelegatingUnknown methods always act on this object.
class HttpProtocol : public INonDelegatingUnknown,
                     public IInternetProtocolEx,
                     public IInternetProtocolInfo{
public:
  // pOuterUnknown: controlling unknown when aggregated, NULL otherwise.
  // `explicit` keeps an IUnknown* from converting to HttpProtocol implicitly.
  explicit HttpProtocol(IUnknown* pOuterUnknown);
  virtual ~HttpProtocol();

public:

  // INonDelegatingUnknown
  // Only the protocol interfaces can be queried here, not the sink interfaces.
  STDMETHOD(NonDelegatingQueryInterface)(REFIID riid, void** ppvObject);
  STDMETHOD_(ULONG, NonDelegatingAddRef)();
  STDMETHOD_(ULONG, NonDelegatingRelease)();

  // IUnknown
  STDMETHOD(QueryInterface)(REFIID riid, void** ppvObject);
  STDMETHOD_(ULONG, AddRef)();
  STDMETHOD_(ULONG, Release)();

  // IInternetProtocolRoot
  STDMETHOD(Start)(LPCWSTR url, IInternetProtocolSink* protocol_sink, IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved);
  STDMETHOD(Continue)(PROTOCOLDATA* pProtocolData);
  STDMETHOD(Abort)(HRESULT reason, DWORD options);
  STDMETHOD(Terminate)(DWORD options);
  STDMETHOD(Suspend)();
  STDMETHOD(Resume)();

  // IInternetProtocol : public IInternetProtocolRoot
  STDMETHOD(Read)(void* pv, ULONG size, ULONG* pcbRead);
  STDMETHOD(Seek)(LARGE_INTEGER move, DWORD origin, ULARGE_INTEGER* new_position);
  STDMETHOD(LockRequest)(DWORD options);
  STDMETHOD(UnlockRequest)();

  // IInternetProtocolEx : public IInternetProtocol
  STDMETHOD(StartEx)(IUri* uri, IInternetProtocolSink* protocol_sink, IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved);

  // IInternetProtocolInfo
  STDMETHOD(ParseUrl)(LPCWSTR pwzUrl, PARSEACTION ParseAction, DWORD dwParseFlags, LPWSTR pwzResult,  DWORD cchResult,  DWORD *pcchResult, DWORD dwReserved) ;
  STDMETHOD(CombineUrl)( LPCWSTR pwzBaseUrl,  LPCWSTR pwzRelativeUrl,  DWORD dwCombineFlags, LPWSTR pwzResult,DWORD cchResult,DWORD *pcchResult,DWORD dwReserved) ;
  STDMETHOD(CompareUrl)( LPCWSTR pwzUrl1,LPCWSTR pwzUrl2,DWORD dwCompareFlags) ;
  STDMETHOD(QueryInfo)(LPCWSTR pwzUrl,  QUERYOPTION OueryOption, DWORD dwQueryFlags, LPVOID pBuffer, DWORD cbBuffer, DWORD *pcbBuf, DWORD dwReserved);

private:
  // Text form of the verb recorded in bind_info_.
  std::wstring GetVerbStr() const ;
  // Request-body pointer and size taken from bind_info_.
  bool  GetDataToSend(char** lplpData, DWORD* pdwSize) const ;

private:

  // This object's own reference count (used by the NonDelegating* methods).
  volatile LONG reference_count_;

  // Controlling unknown supplied at construction; NULL when not aggregated.
  IUnknown* outer_unknown_;

  // This object viewed through its INonDelegatingUnknown sub-object.
  IUnknown* inner_unknown_;

  // Sink passed to Start().
  CComPtr<IInternetProtocolSink> spSink_;

  // Bind-info object passed to Start().
  CComPtr<IInternetBindInfo> spBindinfo_;

  // Service provider queried from the sink, or from the bind info as fallback.
  CComPtr<IServiceProvider> spServiceProvider_;

  // Filled in by IInternetBindInfo::GetBindInfo during Start().
  BINDINFO bind_info_;

  // Bind flags returned alongside bind_info_.
  DWORD grf_BindF_;

  // Set by LockRequest(), cleared by UnlockRequest().
  bool has_lock_request_;

  // URL passed to Start().
  GURL bind_url_;

  DISALLOW_COPY_AND_ASSIGN(HttpProtocol);
};

}  //namespace trident

#endif  // TRIDENT_PROTOCOL_HTTP_PROTOCOL_H_







// http_protocol_factory.cc

#include "trident/glue/protocol_impl/http_protocol_factory.h"
#include "base/logging.h"
#include "trident/glue/protocol_impl/http_protocol.h"

namespace trident {

// Creates the factory with an initial reference count of one and fetches the
// class factory registered for the stock HTTP or HTTPS protocol handler into
// origin_factory_.
//
// is_https_protocol: selects CLSID_HttpSProtocol when true, CLSID_HttpProtocol
//                    otherwise.
// A failed lookup is only caught by the DCHECKs.
HttpProtocolFactory::HttpProtocolFactory(bool is_https_protocol) : reference_count_(1) {
  const CLSID& protocol_clsid =
      is_https_protocol ? CLSID_HttpSProtocol : CLSID_HttpProtocol;
  const HRESULT lookup_result = ::CoGetClassObject(
      protocol_clsid, CLSCTX_INPROC_SERVER, NULL, IID_IClassFactory,
      (void**)&origin_factory_);
  DCHECK(lookup_result == S_OK);
  DCHECK(origin_factory_ != NULL);
}

// Nothing to do explicitly: origin_factory_ is a CComPtr member and cleans up
// after itself.
HttpProtocolFactory::~HttpProtocolFactory() {}

// IUnknown
// Answers for IUnknown and IClassFactory only.
//
// Returns E_INVALIDARG when ppvObject is NULL, E_NOINTERFACE (with *ppvObject
// set to NULL) for any other riid, and S_OK with one added reference on a
// match.
STDMETHODIMP HttpProtocolFactory::QueryInterface(REFIID riid, void** ppvObject) {
  if (ppvObject == NULL) {
    return E_INVALIDARG;
  }

  *ppvObject = NULL;
  if (riid == IID_IUnknown) {
    *ppvObject = static_cast<IUnknown*>(this);
  } else if (riid == IID_IClassFactory) {
    *ppvObject = static_cast<IClassFactory*>(this);
  }

  if (*ppvObject == NULL) {
    return E_NOINTERFACE;
  }

  static_cast<IUnknown*>(*ppvObject)->AddRef();
  return S_OK;
}

// Atomically increments the factory's reference count and returns the new
// value.
STDMETHODIMP_(ULONG) HttpProtocolFactory::AddRef() {
  const ULONG updated = ::InterlockedIncrement(&reference_count_);
  return updated;
}

// Atomically decrements the factory's reference count, destroying the factory
// when it reaches zero. Returns the count after the decrement.
STDMETHODIMP_(ULONG) HttpProtocolFactory::Release() {
  const ULONG remaining = ::InterlockedDecrement(&reference_count_);
  if (remaining != 0) {
    return remaining;
  }

  delete this;
  return 0;
}

// IClassFactory
// Creates an HttpProtocol and hands back the requested interface on it.
//
// pUnkOuter: controlling unknown for aggregation, or NULL. When non-NULL, only
//            IID_IUnknown may be requested (else CLASS_E_NOAGGREGATION).
// riid:      interface requested on the new object.
// ppvObject: receives the interface; set to NULL on every failure path.
// Returns E_INVALIDARG for a NULL ppvObject, E_NOINTERFACE when the new object
// does not support riid, S_OK otherwise.
//
// The new object starts with a reference count of zero. A temporary reference
// is held across the interface query and then dropped: on success the object
// lives on through the reference the query added, on failure dropping the
// temporary reference destroys it exactly once. No manual `delete` is needed,
// which avoids destroying the object a second time.
STDMETHODIMP HttpProtocolFactory::CreateInstance(IUnknown* pUnkOuter, REFIID riid, void** ppvObject) {
  if (ppvObject == NULL) {
    return E_INVALIDARG;
  }
  *ppvObject = NULL;

  if (pUnkOuter && riid != IID_IUnknown) {
    return CLASS_E_NOAGGREGATION;
  }

  HttpProtocol* http_protocol = new HttpProtocol(pUnkOuter);
  http_protocol->NonDelegatingAddRef();
  const HRESULT query_result = http_protocol->NonDelegatingQueryInterface(riid, ppvObject);
  http_protocol->NonDelegatingRelease();

  if (query_result != S_OK) {
    *ppvObject = NULL;
    return E_NOINTERFACE;
  }
  return S_OK;
}

// Maps server locking onto the factory's own reference count: a lock adds a
// reference, an unlock releases one. Always returns S_OK.
STDMETHODIMP HttpProtocolFactory::LockServer(BOOL fLock) {
  if (!fLock) {
    Release();
    return S_OK;
  }

  AddRef();
  return S_OK;
}

}






// http_protocol_factory.h

#ifndef TRIDENT_HTTP_PROTOCOL_FACTORY_H_
#define TRIDENT_HTTP_PROTOCOL_FACTORY_H_

#include <atlbase.h>
#include <Unknwn.h>
#include "base/basictypes.h"

namespace trident {

// Class factory that hands out HttpProtocol instances.
// The destructor is private; the object deletes itself from Release() when its
// reference count reaches zero.
class HttpProtocolFactory : public IClassFactory {
public:

  // is_https_protocol selects which stock handler's class factory (HTTPS or
  // HTTP) is looked up and kept in origin_factory_.
  explicit HttpProtocolFactory(bool is_https_protocol);

  // IUnknown
  STDMETHOD(QueryInterface)(REFIID riid, void** ppvObject);
  STDMETHOD_(ULONG, AddRef)();
  STDMETHOD_(ULONG, Release)();

  // IClassFactory
  STDMETHOD(CreateInstance)(IUnknown* pUnkOuter, REFIID riid, void** ppvObject);
  STDMETHOD(LockServer)(BOOL fLock);

private:
  // Private so the object can only be destroyed through Release().
  virtual ~HttpProtocolFactory();

  // Reference count; the constructor starts it at one.
  volatile ULONG reference_count_;
  // Class factory obtained from CoGetClassObject in the constructor.
  CComPtr<IClassFactory> origin_factory_;
  DISALLOW_IMPLICIT_CONSTRUCTORS(HttpProtocolFactory);
};

}

#endif  // TRIDENT_HTTP_PROTOCOL_FACTORY_H_