public class RBM extends Object
Modifier and Type | Field and Description |
---|---|
protected double |
COST |
protected Jama.Matrix |
dW_ |
protected double |
LEARNING_RATE |
protected int |
m_E |
protected int |
m_H |
protected Random |
m_R |
protected double |
MOMENTUM |
protected Jama.Matrix |
W |
Constructor and Description |
---|
RBM()
RBM - Create an RBM with default options.
|
RBM(String[] options)
RBM - Create an RBM with 'options' (using WEKA-style option processing).
|
Modifier and Type | Method and Description |
---|---|
double |
calculateError(Jama.Matrix X)
Calculate the error on X using the current weights.
|
Jama.Matrix |
epoch(Jama.Matrix X_0)
Epoch - Run X through one epoch of CD on the RBM.
|
int |
getE() |
int |
getH() |
double |
getLearningRate() |
double |
getMomentum() |
String[] |
getOptions()
GetOptions - WEKA-style option processing.
|
Jama.Matrix |
getW() |
Jama.Matrix[] |
getWs() |
void |
initWeights(int d)
Initialize W, and make dW_ (for momentum) of the same dimensions.
|
static void |
main(String[] argv)
Main - do some test routines.
|
protected Jama.Matrix |
makeW(int d,
int h) |
static Jama.Matrix |
makeW(int d,
int h,
Random r)
Make W matrix of dimensions (d+1) x (h+1) (+1 for biases).
|
double[] |
prob_x(double[] z_)
Visible Activation Probability - returns P(x|z), i.e., p(x[j]==1|z) for each j-th element.
|
Jama.Matrix |
prob_X(Jama.Matrix Z)
Visible Activation Probability - returns P(X|Z).
|
double[] |
prob_z(double[] x_)
Hidden Activation Probability - returns P(z|x), i.e., p(z[i]==1|x) for each i-th element.
|
double[][] |
prob_Z(double[][] X_)
Hidden Activation Probability - returns P(Z|X).
|
Jama.Matrix |
prob_Z(Jama.Matrix X)
Hidden Activation Probability - returns P(Z|X).
|
double[][] |
propUp(double[][] X_)
Hidden Activation Value.
|
Jama.Matrix |
sample_epoch(Jama.Matrix X_0) |
double[] |
sample_x(double[] z_)
Sample Visible - returns x[j] ~ p(x[j]==1|z) for each j-th element.
|
Jama.Matrix |
sample_X(Jama.Matrix Z)
Sample Visible - returns X ~ P(X|Z).
|
double[] |
sample_z(double[] x_)
Sample Hidden Value - returns z[i] ~ p(z[i]==1|x) for each i-th element.
|
Jama.Matrix |
sample_Z(Jama.Matrix X)
Sample Hidden Value - returns Z ~ P(Z|X).
|
void |
setE(int n)
SetE - set the number of epochs (if n is negative, it means max epochs).
|
void |
setH(int h) |
void |
setLearningRate(double r) |
void |
setMomentum(double m) |
void |
setOptions(String[] options)
Set Options - WEKA-style option processing.
|
void |
setSeed(int seed) |
String |
toString()
ToString - return a String representation of the weight Matrix defining this RBM.
|
double |
train(double[][] X_)
Train - Setup and train the RBM on X, over m_E epochs.
|
double |
train(double[][] X_,
int batchSize)
Train - Setup and batch-train the RBM on X.
|
double |
train(double[][] X_,
int batchSize,
Random r)
Train - Setup and batch-train the RBM on X, with some random sampling involved.
|
void |
update(double[] x_)
Update - On raw data (with no bias column)
|
void |
update(double[][] X_)
Update - On raw data (with no bias column)
|
void |
update(double[] x_,
double s)
Update - On raw data (with no bias column)
|
void |
update(Jama.Matrix X)
Update - Carry out one epoch of CD, update W.
|
void |
update(Jama.Matrix X,
double s)
Update - Carry out one epoch of CD, update W.
|
protected double LEARNING_RATE
protected double MOMENTUM
protected double COST
protected int m_E
protected int m_H
protected Jama.Matrix W
protected Jama.Matrix dW_
protected Random m_R
public void setOptions(String[] options) throws Exception
Exception
public String[] getOptions() throws Exception
Exception
public double[] prob_z(double[] x_)
x_
- x (without bias)
public double[][] prob_Z(double[][] X_)
X_
- X (without bias)
public Jama.Matrix prob_Z(Jama.Matrix X)
X
- X (bias included)
public double[][] propUp(double[][] X_)
X_
- X (without bias)
public Jama.Matrix sample_Z(Jama.Matrix X)
X
- X (bias included)
public double[] sample_z(double[] x_)
x_
- x (without bias)
public double[] sample_x(double[] z_)
z_
- z (without bias)
public Jama.Matrix sample_X(Jama.Matrix Z)
Z
- Z (bias included)
public double[] prob_x(double[] z_)
z_
- z (without bias)
public Jama.Matrix prob_X(Jama.Matrix Z)
Z
- Z (bias included)
public static Jama.Matrix makeW(int d, int h, Random r)
d
- number of rows (visible units)
h
- number of columns (hidden units)
r
- for getting random numbers
protected Jama.Matrix makeW(int d, int h)
public void initWeights(int d)
d
- number of visible units
public void update(Jama.Matrix X)
X
- X
public void update(Jama.Matrix X, double s)
X
- X
s
- multiply the gradient by this scalar
public void update(double[][] X_)
X_
- raw double[][] data (with no bias column)
public void update(double[] x_)
x_
- raw double[] data (with no bias column)
public void update(double[] x_, double s)
x_
- raw double[] data (with no bias column)
s
- multiply the gradient by this scalar
public double train(double[][] X_) throws Exception
X_
- X
Exception
public double train(double[][] X_, int batchSize) throws Exception
X_
- X
batchSize
- the batch size
Exception
public double train(double[][] X_, int batchSize, Random r) throws Exception
X_
- X
batchSize
- the batch size
r
- the randomness
Exception
public double calculateError(Jama.Matrix X)
X
- X
public Jama.Matrix epoch(Jama.Matrix X_0)
X_0
- The input matrix (includes bias column).
public Jama.Matrix sample_epoch(Jama.Matrix X_0)
public void setH(int h)
public int getH()
public void setE(int n)
public int getE()
public void setLearningRate(double r)
public double getLearningRate()
public void setMomentum(double m)
public double getMomentum()
public void setSeed(int seed)
public Jama.Matrix[] getWs()
public Jama.Matrix getW()
public String toString()
Copyright © 2017. All Rights Reserved.