重庆分公司,新征程启航
为企业提供网站建设、域名注册、服务器等服务
这篇文章主要为大家展示了如何使用Keras预训练模型ResNet50进行图像分类,内容简而易懂,希望大家可以学习一下,学习完之后肯定会有收获的,下面让小编带大家一起来看看吧。
成都创新互联公司网站建设由有经验的网站设计师、开发人员和项目经理组成的专业建站团队,负责网站视觉设计、用户体验优化、交互设计和前端开发等方面的工作,以确保网站外观精美、成都网站设计、做网站易于使用并且具有良好的响应性。Keras提供了一些用ImageNet训练过的模型:Xception,VGG16,VGG19,ResNet50,InceptionV3。在使用这些模型的时候,有一个参数include_top表示是否包含模型顶部的全连接层,如果包含,则可以将图像分为ImageNet中的1000类,如果不包含,则可以利用这些参数来做一些定制的事情。
在运行时自动下载有可能会失败,需要去网站中手动下载,放在“~/.keras/models/”中,使用WinPython则在“settings/.keras/models/”中。
修正:表示当前是训练模式还是测试模式的参数K.learning_phase()文中表述和使用有误,在该函数说明中可以看到:
The learning phase flag is a bool tensor (0 = test, 1 = train),所以0是测试模式,1是训练模式,部分网络结构下两者有差别。
这里使用ResNet50预训练模型,对Caltech201数据集进行图像分类。只有CPU,运行较慢,但是在训练集固定的情况下,较慢的过程只需要运行一次。
该预训练模型的中文文档介绍在http://keras-cn.readthedocs.io/en/latest/other/application/#resnet50。
我使用的版本:
1.Ubuntu 16.04.3
2.Python 2.7
3.Keras 2.0.8
4.Tensoflow 1.3.0
5.Numpy 1.13.1
6.python-opencv 2.4.9.1+dfsg-1.5ubuntu1
7.h6py 2.7.0
从文件夹中提取图像数据的方式:
函数:
def eachFile(filepath): #将目录内的文件名放入列表中 pathDir = os.listdir(filepath) out = [] for allDir in pathDir: child = allDir.decode('gbk') # .decode('gbk')是解决中文显示乱码问题 out.append(child) return out def get_data(data_name,train_left=0.0,train_right=0.7,train_all=0.7,resize=True,data_format=None,t=''): #从文件夹中获取图像数据 file_name = os.path.join(pic_dir_out,data_name+t+'_'+str(train_left)+'_'+str(train_right)+'_'+str(Width)+"X"+str(Height)+".h6") print file_name if os.path.exists(file_name): #判断之前是否有存到文件中 f = h6py.File(file_name,'r') if t=='train': X_train = f['X_train'][:] y_train = f['y_train'][:] f.close() return (X_train, y_train) elif t=='test': X_test = f['X_test'][:] y_test = f['y_test'][:] f.close() return (X_test, y_test) else: return data_format = conv_utils.normalize_data_format(data_format) pic_dir_set = eachFile(pic_dir_data) X_train = [] y_train = [] X_test = [] y_test = [] label = 0 for pic_dir in pic_dir_set: print pic_dir_data+pic_dir if not os.path.isdir(os.path.join(pic_dir_data,pic_dir)): continue pic_set = eachFile(os.path.join(pic_dir_data,pic_dir)) pic_index = 0 train_count = int(len(pic_set)*train_all) train_l = int(len(pic_set)*train_left) train_r = int(len(pic_set)*train_right) for pic_name in pic_set: if not os.path.isfile(os.path.join(pic_dir_data,pic_dir,pic_name)): continue img = cv2.imread(os.path.join(pic_dir_data,pic_dir,pic_name)) if img is None: continue if (resize): img = cv2.resize(img,(Width,Height)) img = img.reshape(-1,Width,Height,3) if (pic_index < train_count): if t=='train': if (pic_index >= train_l and pic_index < train_r): X_train.append(img) y_train.append(label) else: if t=='test': X_test.append(img) y_test.append(label) pic_index += 1 if len(pic_set) <> 0: label += 1 f = h6py.File(file_name,'w') if t=='train': X_train = np.concatenate(X_train,axis=0) y_train = np.array(y_train) f.create_dataset('X_train', data = X_train) f.create_dataset('y_train', data = y_train) f.close() return (X_train, y_train) elif t=='test': X_test = np.concatenate(X_test,axis=0) y_test = np.array(y_test) f.create_dataset('X_test', data = X_test) f.create_dataset('y_test', data = y_test) f.close() return (X_test, y_test) else: return