propclass/neuralnet.h
00001 /* 00002 Crystal Space Entity Layer 00003 Copyright (C) 2007 by Jorrit Tyberghein 00004 00005 Neural Network Property Class 00006 Copyright (C) 2007 by Mat Sutcliffe 00007 00008 This library is free software; you can redistribute it and/or 00009 modify it under the terms of the GNU Library General Public 00010 License as published by the Free Software Foundation; either 00011 version 2 of the License, or (at your option) any later version. 00012 00013 This library is distributed in the hope that it will be useful, 00014 but WITHOUT ANY WARRANTY; without even the implied warranty of 00015 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 00016 Library General Public License for more details. 00017 00018 You should have received a copy of the GNU Library General Public 00019 License along with this library; if not, write to the Free 00020 Software Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 00021 */ 00022 00023 #ifndef __CEL_PF_NEURALNET__ 00024 #define __CEL_PF_NEURALNET__ 00025 00026 #include "cstypes.h" 00027 #include "csutil/scf.h" 00028 #include "csutil/refcount.h" 00029 #include "csutil/array.h" 00030 #include "csgeom/math.h" 00031 00032 #include "physicallayer/datatype.h" 00033 00034 class celNNActivationFunc; 00035 00049 struct iCelNNWeights : public virtual iBase 00050 { 00051 SCF_INTERFACE(iCelNNWeights, 0, 0, 1); 00052 00054 virtual csArray< csArray< csArray<float> > >& Data() = 0; 00055 00057 virtual const csArray< csArray< csArray<float> > >& Data() const = 0; 00058 00060 virtual csPtr<iCelNNWeights> Clone() const = 0; 00061 }; 00062 00095 struct iPcNeuralNet : public virtual iBase 00096 { 00097 SCF_INTERFACE(iPcNeuralNet, 0, 0, 1); 00098 00103 virtual void SetSize(size_t inputs, size_t outputs, size_t layers) = 0; 00104 00115 virtual void SetComplexity(const char *name) = 0; 00116 00123 virtual void SetLayerSizes(const csArray<size_t> &sizes) = 0; 00124 00128 virtual void SetActivationFunc(celNNActivationFunc *) = 0; 00129 00135 virtual bool Validate() = 0; 00136 00140 virtual void SetInput(size_t index, const celData &value) = 0; 00141 00145 virtual const celData& GetOutput(size_t index) const = 0; 00146 00150 virtual void SetInputs(const csArray<celData> &values) = 0; 00151 00155 virtual const csArray<celData>& GetOutputs() const = 0; 00156 00161 virtual void Process() = 0; 00162 00166 virtual csPtr<iCelNNWeights> CreateEmptyWeights() const = 0; 00167 00171 virtual void GetWeights(iCelNNWeights *out) const = 0; 00172 00176 virtual bool SetWeights(const iCelNNWeights *in) = 0; 00177 00181 virtual bool CacheWeights(const char *scope, uint32 id) const = 0; 00182 00186 virtual bool LoadCachedWeights(const char *scope, uint32 id) = 0; 00187 }; 00188 00202 class celNNActivationFunc : public virtual csRefCount 00203 { 00204 public: 00206 virtual void Function(celData &data) = 0; 00207 00209 virtual celDataType GetDataType() = 0; 00210 00212 virtual ~celNNActivationFunc() {} 00213 00214 protected: 00226 template <typename T> 00227 static const T& GetFrom(const celData &input); 00228 00239 template <typename T> 00240 static celDataType DataType(); 00241 }; 00242 00243 template<> 00244 inline const float& celNNActivationFunc::GetFrom<float>(const celData &input) 00245 { 00246 return input.value.f; 00247 } 00248 template<> 00249 inline const int8& celNNActivationFunc::GetFrom<int8>(const celData &input) 00250 { 00251 return input.value.b; 00252 } 00253 template<> 00254 inline const uint8& celNNActivationFunc::GetFrom<uint8>(const celData &input) 00255 { 00256 return input.value.ub; 00257 } 00258 template<> 00259 inline const int16& celNNActivationFunc::GetFrom<int16>(const celData &input) 00260 { 00261 return input.value.w; 00262 } 00263 template<> 00264 inline const uint16& celNNActivationFunc::GetFrom<uint16>(const celData &input) 00265 { 00266 return input.value.uw; 00267 } 00268 template<> 00269 inline const int32& celNNActivationFunc::GetFrom<int32>(const celData &input) 00270 { 00271 return input.value.l; 00272 } 00273 template<> 00274 inline const uint32& celNNActivationFunc::GetFrom<uint32>(const celData &input) 00275 { 00276 return input.value.ul; 00277 } 00278 00279 template<> 00280 inline celDataType celNNActivationFunc::DataType<int8>() 00281 { 00282 return CEL_DATA_BYTE; 00283 } 00284 template<> 00285 inline celDataType celNNActivationFunc::DataType<uint8>() 00286 { 00287 return CEL_DATA_UBYTE; 00288 } 00289 template<> 00290 inline celDataType celNNActivationFunc::DataType<int16>() 00291 { 00292 return CEL_DATA_WORD; 00293 } 00294 template<> 00295 inline celDataType celNNActivationFunc::DataType<uint16>() 00296 { 00297 return CEL_DATA_UWORD; 00298 } 00299 template<> 00300 inline celDataType celNNActivationFunc::DataType<int32>() 00301 { 00302 return CEL_DATA_LONG; 00303 } 00304 template<> 00305 inline celDataType celNNActivationFunc::DataType<uint32>() 00306 { 00307 return CEL_DATA_ULONG; 00308 } 00309 template<> 00310 inline celDataType celNNActivationFunc::DataType<float>() 00311 { 00312 return CEL_DATA_FLOAT; 00313 } 00314 00326 template <typename T> 00327 class celNopActivationFunc : public celNNActivationFunc 00328 { 00329 public: 00330 virtual void Function(celData &data) {} 00331 virtual celDataType GetDataType() { return DataType<T>(); } 00332 virtual ~celNopActivationFunc() {} 00333 }; 00334 00346 template <typename T> 00347 class celStepActivationFunc : public celNNActivationFunc 00348 { 00349 public: 00350 virtual void Function(celData &data) 00351 { 00352 const T &val = GetFrom<T>(data); 00353 data.Set(T (val > 1 ? 1 : 0)); 00354 } 00355 virtual celDataType GetDataType() { return DataType<T>(); } 00356 virtual ~celStepActivationFunc() {} 00357 }; 00358 00370 template <typename T> 00371 class celLogActivationFunc : public celNNActivationFunc 00372 { 00373 public: 00374 virtual void Function(celData &data) 00375 { 00376 const T &val = GetFrom<T>(data); 00377 double e_v = log(fabs((double) val)); // log may return not-a-number 00378 data.Set((T) (csNormal(e_v) ? e_v : 0.0)); 00379 } 00380 virtual celDataType GetDataType() { return DataType<T>(); } 00381 virtual ~celLogActivationFunc() {} 00382 }; 00383 00395 template <typename T> 00396 class celAtanActivationFunc : public celNNActivationFunc 00397 { 00398 public: 00399 virtual void Function(celData &data) 00400 { 00401 const T &val = GetFrom<T>(data); 00402 data.Set((T) atan((double) val)); 00403 } 00404 virtual celDataType GetDataType() { return DataType<T>(); } 00405 virtual ~celAtanActivationFunc() {} 00406 }; 00407 00420 template <typename T> 00421 class celTanhActivationFunc : public celNNActivationFunc 00422 { 00423 public: 00424 virtual void Function(celData &data) 00425 { 00426 const T &val = GetFrom<T>(data); 00427 data.Set((T) tanh((double) val)); 00428 } 00429 virtual celDataType GetDataType() { return DataType<T>(); } 00430 virtual ~celTanhActivationFunc() {} 00431 }; 00432 00444 template <typename T> 00445 class celExpActivationFunc : public celNNActivationFunc 00446 { 00447 public: 00448 virtual void Function(celData &data) 00449 { 00450 const T &val = GetFrom<T>(data); 00451 double e_v = exp((double) val); // exp may return infinite 00452 data.Set((T) (csNormal(e_v) ? e_v : 0.0)); 00453 } 00454 virtual celDataType GetDataType() { return DataType<T>(); } 00455 virtual ~celExpActivationFunc() {} 00456 }; 00457 00467 class celInvActivationFunc : public celNNActivationFunc 00468 { 00469 public: 00470 virtual void Function(celData &data) 00471 { 00472 const float &val = GetFrom<float>(data); 00473 data.Set(1.0f / val); 00474 } 00475 virtual celDataType GetDataType() { return CEL_DATA_FLOAT; } 00476 virtual ~celInvActivationFunc() {} 00477 }; 00478 00490 template <typename T> 00491 class celSqrActivationFunc : public celNNActivationFunc 00492 { 00493 public: 00494 virtual void Function(celData &data) 00495 { 00496 const T &val = GetFrom<T>(data); 00497 data.Set(val * val); 00498 } 00499 virtual celDataType GetDataType() { return DataType<T>(); } 00500 virtual ~celSqrActivationFunc() {} 00501 }; 00502 00514 template <typename T> 00515 class celGaussActivationFunc : public celNNActivationFunc 00516 { 00517 public: 00518 virtual void Function(celData &data) 00519 { 00520 const T &val = GetFrom<T>(data); 00521 data.Set((T) exp((double) -(val * val))); 00522 } 00523 virtual celDataType GetDataType() { return DataType<T>(); } 00524 virtual ~celGaussActivationFunc() {} 00525 }; 00526 00538 template <typename T> 00539 class celSinActivationFunc : public celNNActivationFunc 00540 { 00541 public: 00542 virtual void Function(celData &data) 00543 { 00544 const T &val = GetFrom<T>(data); 00545 data.Set((T) sin((double) val)); 00546 } 00547 virtual celDataType GetDataType() { return DataType<T>(); } 00548 virtual ~celSinActivationFunc() {} 00549 }; 00550 00562 template <typename T> 00563 class celCosActivationFunc : public celNNActivationFunc 00564 { 00565 public: 00566 virtual void Function(celData &data) 00567 { 00568 const T &val = GetFrom<T>(data); 00569 data.Set((T) cos((double) val)); 00570 } 00571 virtual celDataType GetDataType() { return DataType<T>(); } 00572 virtual ~celCosActivationFunc() {} 00573 }; 00574 00586 template <typename T> 00587 class celElliottActivationFunc : public celNNActivationFunc 00588 { 00589 public: 00590 virtual void Function(celData &data) 00591 { 00592 const T &val = GetFrom<T>(data); 00593 data.Set(val / (1 + (T)fabs((double) val))); 00594 } 00595 virtual celDataType GetDataType() { return DataType<T>(); } 00596 virtual ~celElliottActivationFunc() {} 00597 }; 00598 00610 template <typename T> 00611 class celSigActivationFunc : public celNNActivationFunc 00612 { 00613 public: 00614 virtual void Function(celData &data) 00615 { 00616 const T &val = GetFrom<T>(data); 00617 data.Set(T(1) / (T)(1 + exp((double) -val))); 00618 } 00619 virtual celDataType GetDataType() { return DataType<T>(); } 00620 virtual ~celSigActivationFunc() {} 00621 }; 00622 00623 #endif // __CEL_PF_NEURALNET__ 00624
Generated for CEL: Crystal Entity Layer 2.0 by doxygen 1.6.1
