采用torchvision.datasets.ImageFolder這個(gè)接口來讀取圖像數(shù)據(jù),該接口默認(rèn)你的訓(xùn)練數(shù)據(jù)是按照一個(gè)類別存放在一個(gè)文件夾下。但是有些情況下你的圖像數(shù)據(jù)不是這樣維護(hù)的,比如一個(gè)文件夾下面各個(gè)類別的圖像數(shù)據(jù)都有,同時(shí)用一個(gè)對(duì)應(yīng)的標(biāo)簽文件,比如txt文件來維護(hù)圖像和標(biāo)簽的對(duì)應(yīng)關(guān)系,在這種情況下就不能用torchvision.datasets.ImageFolder來讀取數(shù)據(jù)了,需要自定義一個(gè)數(shù)據(jù)讀取接口。
繼承的類是torch.utils.data.Dataset,主要包含三個(gè)方法:初始化init,獲取圖像getitem,數(shù)據(jù)集數(shù)量 len。init方法中先通過find_classes函數(shù)得到分類的類別名(classes)和類別名與數(shù)字類別的映射關(guān)系字典(class_to_idx)。然后通過make_dataset函數(shù)得到imags,這個(gè)imags是一個(gè)列表,其中每個(gè)值是一個(gè)tuple,每個(gè)tuple包含兩個(gè)元素:圖像路徑和標(biāo)簽。剩下的就是一些賦值操作了。在getitem方法中最重要的就是 img = self.loader(path)這行,表示數(shù)據(jù)讀取,可以從init方法中看出self.loader采用的是default_loader,這個(gè)default_loader的核心就是用python的PIL庫(kù)的Image模塊來讀取圖像數(shù)據(jù)。
class ImageFolder(data.Dataset):
"""A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader):
classes, class_to_idx = find_classes(root)
imgs = make_dataset(root, class_to_idx)
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
self.root = root
self.imgs = imgs
self.classes = classes
self.class_to_idx = class_to_idx
self.transform = transform
self.target_transform = target_transform
self.loader = loader
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is class_index of the target class.
"""
path, target = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.imgs)
稍微看下default_loader函數(shù),該函數(shù)主要分兩種情況調(diào)用兩個(gè)函數(shù),一般采用pil_loader函數(shù)。
def pil_loader(path):
with open(path, 'rb') as f:
with Image.open(f) as img:
return img.convert('RGB')
def accimage_loader(path):
import accimage
try:
return accimage.Image(path)
except IOError:
# Potentially a decoding problem, fall back to PIL.Image
return pil_loader(path)
def default_loader(path):
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return accimage_loader(path)
else:
return pil_loader(path)
看懂了ImageFolder這個(gè)類,就可以自定義一個(gè)你自己的數(shù)據(jù)讀取接口了。
首先在PyTorch中和數(shù)據(jù)讀取相關(guān)的類基本都要繼承一個(gè)基類:torch.utils.data.Dataset。然后再改寫其中的init、len、getitem等方法即可。
下面假設(shè)img_path是你的圖像文件夾,該文件夾下面放了所有圖像數(shù)據(jù)(包括訓(xùn)練和測(cè)試),然后txt_path下面放了train.txt和val.txt兩個(gè)文件,txt文件中每行都是圖像路徑,tab鍵,標(biāo)簽。所以下面代碼的init方法中self.img_name和self.img_label的讀取方式就跟你數(shù)據(jù)的存放方式有關(guān),你可以根據(jù)你實(shí)際數(shù)據(jù)的維護(hù)方式做調(diào)整。getitem方法沒有做太大改動(dòng),依然采用default_loader方法來讀取圖像。最后在Transform中將每張圖像都封裝成Tensor。
class customData(Dataset):
def __init__(self, img_path, txt_path, dataset = '', data_transforms=None, loader = default_loader):
with open(txt_path) as input_file:
lines = input_file.readlines()
self.img_name = [os.path.join(img_path, line.strip().split('\t')[0]) for line in lines]
self.img_label = [int(line.strip().split('\t')[-1]) for line in lines]
self.data_transforms = data_transforms
self.dataset = dataset
self.loader = loader
def __len__(self):
return len(self.img_name)
def __getitem__(self, item):
img_name = self.img_name[item]
label = self.img_label[item]
img = self.loader(img_name)
if self.data_transforms is not None:
try:
img = self.data_transforms[self.dataset](img)
except:
print("Cannot transform image: {}".format(img_name))
return img, label
定義好了數(shù)據(jù)讀取接口后,怎么用呢?
在代碼中可以這樣調(diào)用。
image_datasets = {x: customData(img_path='/ImagePath',
txt_path=('/TxtFile/' + x + '.txt'),
data_transforms=data_transforms,
dataset=x) for x in ['train', 'val']}
這樣返回的image_datasets就和用torchvision.datasets.ImageFolder類返回的數(shù)據(jù)類型一樣.
有了image_datasets,然后依然用torch.utils.data.DataLoader類來做進(jìn)一步封裝,將這個(gè)batch的圖像數(shù)據(jù)和標(biāo)簽都分別封裝成Tensor。
dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
batch_size=batch_size,
shuffle=True) for x in ['train', 'val']}
另外,每次迭代生成的模型要怎么保存呢?非常簡(jiǎn)單,那就是用torch.save。輸入就是你的模型和要保存的路徑及模型名稱,如果這個(gè)output文件夾沒有,可以手動(dòng)新建一個(gè)或者在代碼里面新建。
torch.save(model, 'output/resnet_epoch{}.pkl'.format(epoch))
最后,關(guān)于多GPU的使用,PyTorch支持多GPU訓(xùn)練模型,假設(shè)你的網(wǎng)絡(luò)是model,那么只需要下面一行代碼(調(diào)用 torch.nn.DataParallel接口)就可以讓后續(xù)的模型訓(xùn)練在0和1兩塊GPU上訓(xùn)練,加快訓(xùn)練速度。
model = torch.nn.DataParallel(model, device_ids=[0,1])
參考
PyTorch使用及源碼解讀
Finetuning Torchvision Models
Welcome to PyTorch Tutorials
PyTorch學(xué)習(xí)之路(level2)——自定義數(shù)據(jù)讀取
https://github.com/miraclewkf/ImageClassification-PyTorch/blob/master/level2/train_customData.py