__init__.py

from importlib import import_module
from dataloader import MSDataLoader      #
from torch.utils.data.dataloader import default_collate

class Data:
    def __init__(self, args, model):
        kwargs = {}
        if not args.cpu:
            kwargs['collate_fn'] = default_collate
            kwargs['pin_memory'] = True
        else:
            kwargs['collate_fn'] = default_collate
            kwargs['pin_memory'] = False

        self.loader_train = None
        if not args.test_only:

            if args.data_train.lower() != 'rrl':
                module_train = import_module('data.' + args.data_train.lower())
                trainset = getattr(module_train, args.data_train)(args)
            else: 
                module_train = import_module('data.' + args.rrl_data.lower())
                trainclass = getattr(module_train, args.rrl_data)
                
                module_train = import_module('data.rrl')
                trainset = getattr(module_train, 'RRL')(trainclass, args, model)

            self.loader_train = MSDataLoader(
                    args,
                    trainset,
                    batch_size=args.batch_size,
                    shuffle=True,
                    **kwargs
                )

        if args.data_test in ['Set5', 'Set14', 'B100', 'Urban100']:
            if not args.benchmark_noise:
                module_test = import_module('data.benchmark')
                testset = getattr(module_test, 'Benchmark')(args, train=False)
            else:
                module_test = import_module('data.benchmark_noise')
                testset = getattr(module_test, 'BenchmarkNoise')(
                    args,
                    train=False
                )

        else:
            if args.data_test.lower() != 'rrl': 
                module_test = import_module('data.' +  args.data_test.lower())
                testset = getattr(module_test, args.data_test)(args, train=False)
            else: 
                module_test = import_module('data.' + args.rrl_data.lower())
                testclass = getattr(module_test, args.rrl_data)

                module_test = import_module('data.rrl')
                testset = getattr(module_test, 'RRL')(testclass, args, model, False)

        self.loader_test = MSDataLoader(
            args,
            testset,
            batch_size=1,
            shuffle=False,
            **kwargs
        )
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容