// Copyright (C) 2002 Ronan Collobert (collober@iro.umontreal.ca)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

#ifndef GRADIENT_MACHINE_INC
#define GRADIENT_MACHINE_INC

#include "Machine.h"

namespace Torch {

/** Gradient machine: a machine which can
    be trained with gradient descent.

    NOTE(review): this class holds raw pointers (#params#,
    #der_params#, #beta#) and declares no copy constructor or
    copy assignment operator, so the compiler-generated ones
    would copy only the pointer values. Copying an instance is
    presumably unsafe (both copies would later free the same
    memory in #freeMemory()#) -- confirm before relying on copies.

    @author Ronan Collobert (collober@iro.umontreal.ca)
*/
class GradientMachine : public Machine
{
  public:

    /// #true# if #freeMemory()# has already been called.
    bool is_free;

    /** Contains all the parameters which will be
        updated by gradient descent.
        Almost all machines will have only one
        node in this list.
        Note that the integer stored in each node
        is the number of parameters contained
        in the pointer of that node.
    */
    List *params;

    /** Contains the derivatives for all parameters.
        Warning: #params# and #der_params#
        must have the same structure
        (same number of nodes, same lengths).
    */
    List *der_params;

    /** Contains the number of parameters.
        It is the sum of the sizes of all the nodes
        contained in #params# (equivalently, in #der_params#).
    */
    int n_params;
    
    /** Contains the derivatives with respect to the inputs.
        Updated by #backward()#.
    */
    real *beta;

    //-----

    /// Default constructor.
    GradientMachine();

    /// Initializes the machine; it only calls #allocateMemory()#.
    virtual void init();

    /** Returns the total number of parameters, i.e. the
        sum of the sizes of all the nodes of #params#.
        Pure virtual: every concrete machine must define it.
    */
    virtual int numberOfParams() = 0;

    /** Called before each training iteration.
        By default, does nothing.
    */
    virtual void iterInitialize();

    /** Given the #inputs# and the derivatives #alpha# with
        respect to the outputs, updates #beta# and #der_params#.
        Pure virtual: every concrete machine must define it.
    */
    virtual void backward(List *inputs, real *alpha) = 0;

    /** Allocates memory.
        By default, given #n_inputs#, #n_outputs# and #n_params#,
        allocates #beta#, #outputs#, #params# and #der_params#.
    */
    virtual void allocateMemory();

    /** Frees the memory allocated by #allocateMemory()#.
        \emph{Does nothing} if #is_free# is #true#.
        Sets #is_free# to #true#.
        This function \emph{has to set} #n_params#.
        (NOTE(review): why freeing must set #n_params# is not
        visible from this header -- confirm in the implementation.)
        All subclasses which redefine this function should call
        it in their own destructor: a virtual call made from the
        base-class destructor does not reach the derived override.
    */
    virtual void freeMemory();

    /// By default, loads the #params#.
    virtual void loadFILE(FILE *file);

    /// By default, saves the #params#.
    virtual void saveFILE(FILE *file);

    //-----

    /** Resets the machine.
        NOTE(review): the default behavior is defined in the
        implementation file, not in this header -- confirm there.
    */
    virtual void reset();

    /// Virtual, so deleting a derived machine through a base pointer runs the derived destructor.
    virtual ~GradientMachine();
};


}

#endif

