CEL

Public API Reference

propclass/neuralnet.h

00001 /*
00002     Crystal Space Entity Layer
00003     Copyright (C) 2007 by Jorrit Tyberghein
00004 
00005     Neural Network Property Class
00006     Copyright (C) 2007 by Mat Sutcliffe
00007 
00008     This library is free software; you can redistribute it and/or
00009     modify it under the terms of the GNU Library General Public
00010     License as published by the Free Software Foundation; either
00011     version 2 of the License, or (at your option) any later version.
00012 
00013     This library is distributed in the hope that it will be useful,
00014     but WITHOUT ANY WARRANTY; without even the implied warranty of
00015     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00016     Library General Public License for more details.
00017 
00018     You should have received a copy of the GNU Library General Public
00019     License along with this library; if not, write to the Free
00020     Software Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00021 */
00022 
00023 #ifndef __CEL_PF_NEURALNET__
00024 #define __CEL_PF_NEURALNET__
00025 
00026 #include "cstypes.h"
00027 #include "csutil/scf.h"
00028 #include "csutil/refcount.h"
00029 #include "csutil/array.h"
00030 #include "csgeom/math.h"
00031 
00032 #include "physicallayer/datatype.h"
00033 
00034 class celNNActivationFunc;
00035 
00049 struct iCelNNWeights : public virtual iBase
00050 { 
00051   SCF_INTERFACE(iCelNNWeights, 0, 0, 1);
00052 
00054   virtual csArray< csArray< csArray<float> > >& Data() = 0;
00055 
00057   virtual const csArray< csArray< csArray<float> > >& Data() const = 0;
00058 
00060   virtual csPtr<iCelNNWeights> Clone() const = 0;
00061 };
00062 
00096 struct iPcNeuralNet : public virtual iBase
00097 {
00098   SCF_INTERFACE(iPcNeuralNet, 0, 0, 1);
00099 
00104   virtual void SetSize(size_t inputs, size_t outputs, size_t layers) = 0;
00105 
00116   virtual void SetComplexity(const char *name) = 0;
00117 
00124   virtual void SetLayerSizes(const csArray<size_t> &sizes) = 0;
00125 
00129   virtual void SetActivationFunc(celNNActivationFunc *) = 0;
00130 
00136   virtual bool Validate() = 0;
00137 
00141   virtual void SetInput(size_t index, const celData &value) = 0;
00142 
00146   virtual const celData& GetOutput(size_t index) const = 0;
00147 
00151   virtual void SetInputs(const csArray<celData> &values) = 0;
00152 
00156   virtual const csArray<celData>& GetOutputs() const = 0;
00157 
00162   virtual void Process() = 0;
00163 
00167   virtual csPtr<iCelNNWeights> CreateEmptyWeights() const = 0;
00168 
00172   virtual void GetWeights(iCelNNWeights *out) const = 0;
00173 
00177   virtual bool SetWeights(const iCelNNWeights *in) = 0;
00178 
00182   virtual bool CacheWeights(const char *scope, uint32 id) const = 0;
00183 
00187   virtual bool LoadCachedWeights(const char *scope, uint32 id) = 0;
00188 };
00189 
00203 class celNNActivationFunc : public virtual csRefCount
00204 {
00205 public:
00207   virtual void Function(celData &data) = 0;
00208 
00210   virtual celDataType GetDataType() = 0;
00211 
00213   virtual ~celNNActivationFunc() {}
00214 
00215 protected:
00227   template <typename T>
00228   static const T& GetFrom(const celData &input);
00229 
00240   template <typename T>
00241   static celDataType DataType();
00242 };
00243 
00244 template<>
00245 inline const float& celNNActivationFunc::GetFrom<float>(const celData &input)
00246 {
00247   return input.value.f;
00248 }
00249 template<>
00250 inline const int8& celNNActivationFunc::GetFrom<int8>(const celData &input)
00251 {
00252   return input.value.b;
00253 }
00254 template<>
00255 inline const uint8& celNNActivationFunc::GetFrom<uint8>(const celData &input)
00256 {
00257   return input.value.ub;
00258 }
00259 template<>
00260 inline const int16& celNNActivationFunc::GetFrom<int16>(const celData &input)
00261 {
00262   return input.value.w;
00263 }
00264 template<>
00265 inline const uint16& celNNActivationFunc::GetFrom<uint16>(const celData &input)
00266 {
00267   return input.value.uw;
00268 }
00269 template<>
00270 inline const int32& celNNActivationFunc::GetFrom<int32>(const celData &input)
00271 {
00272   return input.value.l;
00273 }
00274 template<>
00275 inline const uint32& celNNActivationFunc::GetFrom<uint32>(const celData &input)
00276 {
00277   return input.value.ul;
00278 }
00279 
00280 template<>
00281 inline celDataType celNNActivationFunc::DataType<int8>()
00282 {
00283   return CEL_DATA_BYTE;
00284 }
00285 template<>
00286 inline celDataType celNNActivationFunc::DataType<uint8>()
00287 {
00288   return CEL_DATA_UBYTE;
00289 }
00290 template<>
00291 inline celDataType celNNActivationFunc::DataType<int16>()
00292 {
00293   return CEL_DATA_WORD;
00294 }
00295 template<>
00296 inline celDataType celNNActivationFunc::DataType<uint16>()
00297 {
00298   return CEL_DATA_UWORD;
00299 }
00300 template<>
00301 inline celDataType celNNActivationFunc::DataType<int32>()
00302 {
00303   return CEL_DATA_LONG;
00304 }
00305 template<>
00306 inline celDataType celNNActivationFunc::DataType<uint32>()
00307 {
00308   return CEL_DATA_ULONG;
00309 }
00310 template<>
00311 inline celDataType celNNActivationFunc::DataType<float>()
00312 {
00313   return CEL_DATA_FLOAT;
00314 }
00315 
00327 template <typename T>
00328 class celNopActivationFunc : public celNNActivationFunc
00329 {
00330 public:
00331   virtual void Function(celData &data) {}
00332   virtual celDataType GetDataType() { return DataType<T>(); }
00333   virtual ~celNopActivationFunc() {}
00334 };
00335 
00347 template <typename T>
00348 class celStepActivationFunc : public celNNActivationFunc
00349 {
00350 public:
00351   virtual void Function(celData &data)
00352   {
00353     const T &val = GetFrom<T>(data);
00354     data.Set(T (val > 1 ? 1 : 0));
00355   }
00356   virtual celDataType GetDataType() { return DataType<T>(); }
00357   virtual ~celStepActivationFunc() {}
00358 };
00359 
00371 template <typename T>
00372 class celLogActivationFunc : public celNNActivationFunc
00373 {
00374 public:
00375   virtual void Function(celData &data)
00376   {
00377     const T &val = GetFrom<T>(data);
00378     double e_v = log(fabs((double) val)); // log may return not-a-number
00379     data.Set((T) (csNormal(e_v) ? e_v : 0.0));
00380   }
00381   virtual celDataType GetDataType() { return DataType<T>(); }
00382   virtual ~celLogActivationFunc() {}
00383 };
00384 
00396 template <typename T>
00397 class celAtanActivationFunc : public celNNActivationFunc
00398 {
00399 public:
00400   virtual void Function(celData &data)
00401   {
00402     const T &val = GetFrom<T>(data);
00403     data.Set((T) atan((double) val));
00404   }
00405   virtual celDataType GetDataType() { return DataType<T>(); }
00406   virtual ~celAtanActivationFunc() {}
00407 };
00408 
00421 template <typename T>
00422 class celTanhActivationFunc : public celNNActivationFunc
00423 {
00424 public:
00425   virtual void Function(celData &data)
00426   {
00427     const T &val = GetFrom<T>(data);
00428     data.Set((T) tanh((double) val));
00429   }
00430   virtual celDataType GetDataType() { return DataType<T>(); }
00431   virtual ~celTanhActivationFunc() {}
00432 };
00433 
00445 template <typename T>
00446 class celExpActivationFunc : public celNNActivationFunc
00447 {
00448 public:
00449   virtual void Function(celData &data)
00450   {
00451     const T &val = GetFrom<T>(data);
00452     double e_v = exp((double) val); // exp may return infinite
00453     data.Set((T) (csNormal(e_v) ? e_v : 0.0));
00454   }
00455   virtual celDataType GetDataType() { return DataType<T>(); }
00456   virtual ~celExpActivationFunc() {}
00457 };
00458 
00468 class celInvActivationFunc : public celNNActivationFunc
00469 {
00470 public:
00471   virtual void Function(celData &data)
00472   {
00473     const float &val = GetFrom<float>(data);
00474     data.Set(1.0f / val);
00475   }
00476   virtual celDataType GetDataType() { return CEL_DATA_FLOAT; }
00477   virtual ~celInvActivationFunc() {}
00478 };
00479 
00491 template <typename T>
00492 class celSqrActivationFunc : public celNNActivationFunc
00493 {
00494 public:
00495   virtual void Function(celData &data)
00496   {
00497     const T &val = GetFrom<T>(data);
00498     data.Set(val * val);
00499   }
00500   virtual celDataType GetDataType() { return DataType<T>(); }
00501   virtual ~celSqrActivationFunc() {}
00502 };
00503 
00515 template <typename T>
00516 class celGaussActivationFunc : public celNNActivationFunc
00517 {
00518 public:
00519   virtual void Function(celData &data)
00520   {
00521     const T &val = GetFrom<T>(data);
00522     data.Set((T) exp((double) -(val * val)));
00523   }
00524   virtual celDataType GetDataType() { return DataType<T>(); }
00525   virtual ~celGaussActivationFunc() {}
00526 };
00527 
00539 template <typename T>
00540 class celSinActivationFunc : public celNNActivationFunc
00541 {
00542 public:
00543   virtual void Function(celData &data)
00544   {
00545     const T &val = GetFrom<T>(data);
00546     data.Set((T) sin((double) val));
00547   }
00548   virtual celDataType GetDataType() { return DataType<T>(); }
00549   virtual ~celSinActivationFunc() {}
00550 };
00551 
00563 template <typename T>
00564 class celCosActivationFunc : public celNNActivationFunc
00565 {
00566 public:
00567   virtual void Function(celData &data)
00568   {
00569     const T &val = GetFrom<T>(data);
00570     data.Set((T) cos((double) val));
00571   }
00572   virtual celDataType GetDataType() { return DataType<T>(); }
00573   virtual ~celCosActivationFunc() {}
00574 };
00575 
00587 template <typename T>
00588 class celElliottActivationFunc : public celNNActivationFunc
00589 {
00590 public:
00591   virtual void Function(celData &data)
00592   {
00593     const T &val = GetFrom<T>(data);
00594     data.Set(val / (1 + (T)fabs((double) val)));
00595   }
00596   virtual celDataType GetDataType() { return DataType<T>(); }
00597   virtual ~celElliottActivationFunc() {}
00598 };
00599 
00611 template <typename T>
00612 class celSigActivationFunc : public celNNActivationFunc
00613 {
00614 public:
00615   virtual void Function(celData &data)
00616   {
00617     const T &val = GetFrom<T>(data);
00618     data.Set(T(1) / (T)(1 + exp((double) -val)));
00619   }
00620   virtual celDataType GetDataType() { return DataType<T>(); }
00621   virtual ~celSigActivationFunc() {}
00622 };
00623 
00624 #endif // __CEL_PF_NEURALNET__
00625 

Generated for CEL: Crystal Entity Layer 1.2 by doxygen 1.4.7