TensorFlow产生3类多元高斯分布样本集(onehot编码)的generate()函数

#!/usr/bin/python
#  -*-  coding:UTF-8   -*-

__author__="David Chow"

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from sklearn.utils import shuffle

def generate(sample_size, num_classes, mean, cov, diff, regression):
    """Draw a labelled sample set from several multivariate Gaussians.

    Class 0 is drawn from N(mean, cov); class ``i + 1`` is drawn from
    N(mean + diff[i], cov).  Every class receives
    ``int(sample_size / num_classes)`` samples, and the rows of the result
    are shuffled before returning.

    Args:
        sample_size: total number of samples requested; any remainder after
            dividing by ``num_classes`` is dropped.
        num_classes: number of classes; must equal ``len(diff) + 1``.
        mean: mean vector of class 0 (array-like of length d).
        cov: d x d covariance matrix shared by all classes.
        diff: sequence of ``num_classes - 1`` offset vectors; ``diff[i]`` is
            added to ``mean`` to obtain the mean of class ``i + 1``.
        regression: if False, Y holds one-hot rows of width ``num_classes``
            (label k -> row k of the identity matrix); if True, Y holds the
            class index of every sample as a float array of shape (n,).

    Returns:
        (X, Y) where X has shape (n, d) and Y is as described above, with
        n = int(sample_size / num_classes) * num_classes.

    Raises:
        ValueError: if ``len(diff) + 1 != num_classes``.
    """
    if len(diff) + 1 != num_classes:
        raise ValueError(
            "expected %d offset vectors in diff for %d classes, got %d"
            % (num_classes - 1, num_classes, len(diff)))

    samples_per_class = int(sample_size / num_classes)

    # asarray so that `mean + offset` is element-wise addition even when the
    # caller passes plain Python lists (list + list would concatenate).
    mean = np.asarray(mean, dtype=float)
    centres = [mean] + [mean + np.asarray(d, dtype=float) for d in diff]

    # Draw class 0 first, then each shifted class, in order.
    X0 = np.concatenate(
        [np.random.multivariate_normal(c, cov, samples_per_class) for c in centres])
    Y0 = np.repeat(np.arange(num_classes, dtype=float), samples_per_class)

    if regression:
        Y = Y0
    else:
        # One-hot encoding: 0 -> [1,0,0], 1 -> [0,1,0], 2 -> [0,0,1], ...
        Y = np.eye(num_classes)[Y0.astype(int)]

    # Shuffle X and Y with one shared permutation from the global NumPy RNG,
    # so rows stay aligned and np.random.seed() makes the result reproducible.
    perm = np.random.permutation(len(X0))
    return X0[perm], Y[perm]


# Generate a 3-class 2-D sample set and plot it, one colour per class.
input_dim = 2
np.random.seed(10)
num_classes = 3
mean = np.random.randn(input_dim)
cov = np.eye(input_dim)
X, Y = generate(1000, num_classes, mean, cov, [[3.0, 3.0], [3.0, 0]], False)

# Recover each sample's integer class from its one-hot row.
aa = np.argmax(Y, axis=1)
# Class 0 -> red, class 1 -> blue, anything else -> yellow.
palette = {0: 'r', 1: 'b'}
colors = [palette.get(label, 'y') for label in aa.tolist()]
plt.scatter(X[:, 0], X[:, 1], c=colors)
plt.show()

输出结果（三类样本的散点图，每类一种颜色）如下图所示：

TensorFlow产生3类多元高斯分布样本集(onehot编码)的generate()函数

TensorFlow产生3类多元高斯分布样本集(onehot编码)的generate()函数