from torch.cuda.amp import GradScaler
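    # Assumed context (defined elsewhere in this script): Hnet embeds a secret
    # image into a cover image, JPEG is a (differentiable) compression layer,
    # and Rnet recovers the secret from the compressed container.
    # Hlosses / Rlosses / SumLosses, data_time and batch_time are taken to be
    # running-average meters, and optimizerH / optimizerR update Hnet / Rnet.
    # With opt.half the forward pass runs under autocast (FP16) and the loss is
    # scaled by GradScaler; otherwise the same loop runs in full FP32 below.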
    if opt.half:
        scaler = GradScaler()
        for i, data in enumerate(train_loader, 0):
            data_time.update(time.time() - start_time)
            # Hnet.change(False)
            Hnet.zero_grad()
            Rnet.zero_grad()
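            # Forward pass and losses run in mixed precision; the scaled backward
            # pass and optimizer steps happen outside the autocast region below.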
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                all_pics = data  # every image in this batch is used as a cover image
                this_batch_size = int(all_pics.size()[0])  # actual batch size of this step

                # cover images come from the loader; secret images are sampled separately
                cover_img = all_pics[0:this_batch_size, :, :, :]  # batch_size, 3, 256, 256
                secret_img, target = sample_secret_img(this_batch_size)

                # secret_img = all_pics[this_batch_size:this_batch_size * 2, :, :, :]

                cover_img = cover_img.to(device)
                secret_img = secret_img.to(device)

                # Variable is deprecated since PyTorch 0.4; tensors carry autograd state directly
                secret_imgv = secret_img
                cover_imgv = JPEG(cover_img)  # JPEG-compress the cover image before hiding
                # print(cover_imgv.shape, secret_imgv.shape)
                container = Hnet(cover_imgv, secret_imgv)  # H-net embeds the secret into the cover and outputs the container image

                errH = criterion(container, cover_imgv)  # loss between cover and container
                Hlosses.update(errH.item(), this_batch_size)

                compress_img = JPEG(container)  # simulate JPEG compression of the container

                rev_secret_img = Rnet(compress_img)  # R-net recovers the secret image from the compressed container
                # print(rev_secret_img.shape, secret_imgv.shape)
                errR = criterion(rev_secret_img, secret_imgv)  # loss between secret image and revealed secret image
                Rlosses.update(errR.item(), this_batch_size)

                betaerrR_secret = opt.beta * errR
                err_sum = errH + betaerrR_secret
                SumLosses.update(err_sum.item(), this_batch_size)

            # Scale the loss to avoid FP16 gradient underflow, then backpropagate
            scaler.scale(err_sum).backward()
            # Step both optimizers through the scaler (steps are skipped if inf/NaN gradients are detected)
            scaler.step(optimizerH)
            scaler.step(optimizerR)
            # Update the scale factor once per iteration
            scaler.update()
            batch_time.update(time.time() - start_time)
            start_time = time.time()

            log = '[%d/%d][%d/%d]\tLoss_H: %.4f Loss_R: %.4f Loss_sum: %.4f \tdatatime: %.4f \tbatchtime: %.4f' % (
                epoch, opt.niter, i, len(train_loader),
                Hlosses.val, Rlosses.val, SumLosses.val, data_time.val, batch_time.val)

            if i % opt.logFrequency == 0:
                print_log(log, logPath)
            else:
                print_log(log, logPath, console=False)

            # generate a result picture every resultPicFrequency steps
            if i % opt.resultPicFrequency == 0:
                save_result_pic(this_batch_size,
                                cover_img, container.data, compress_img.data,
                                secret_img, rev_secret_img.data,
                                epoch, i, opt.trainpics)
            # break

    else:
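        # Full-precision (FP32) path: same pipeline as above, without autocast or GradScaler.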
        for i, data in enumerate(train_loader, 0):
            data_time.update(time.time() - start_time)
            # Hnet.change(False)
            Hnet.zero_grad()
            Rnet.zero_grad()
            all_pics = data  # every image in this batch is used as a cover image
            this_batch_size = int(all_pics.size()[0])  # actual batch size of this step

            # cover images come from the loader; secret images are sampled separately
            cover_img = all_pics[0:this_batch_size, :, :, :]  # batch_size, 3, 256, 256
            secret_img, target = sample_secret_img(this_batch_size)

            # secret_img = all_pics[this_batch_size:this_batch_size * 2, :, :, :]

            cover_img = cover_img.to(device)
            secret_img = secret_img.to(device)

            # Variable is deprecated since PyTorch 0.4; tensors carry autograd state directly
            secret_imgv = secret_img
            cover_imgv = JPEG(cover_img)  # JPEG-compress the cover image before hiding
            # print(cover_imgv.shape, secret_imgv.shape)
            container = Hnet(cover_imgv, secret_imgv)  # H-net embeds the secret into the cover and outputs the container image

            errH = criterion(container, cover_imgv)  # loss between cover and container
            Hlosses.update(errH.item(), this_batch_size)

            compress_img = JPEG(container)  # simulate JPEG compression of the container

            rev_secret_img = Rnet(compress_img)  # R-net recovers the secret image from the compressed container
            # print(rev_secret_img.shape, secret_imgv.shape)
            errR = criterion(rev_secret_img, secret_imgv)  # loss between secret image and revealed secret image
            Rlosses.update(errR.item(), this_batch_size)

            betaerrR_secret = opt.beta * errR
            err_sum = errH + betaerrR_secret
            SumLosses.update(err_sum.item(), this_batch_size)

            err_sum.backward()

            optimizerH.step()
            optimizerR.step()

            batch_time.update(time.time() - start_time)
            start_time = time.time()

            log = '[%d/%d][%d/%d]\tLoss_H: %.4f Loss_R: %.4f Loss_sum: %.4f \tdatatime: %.4f \tbatchtime: %.4f' % (
                epoch, opt.niter, i, len(train_loader),
                Hlosses.val, Rlosses.val, SumLosses.val, data_time.val, batch_time.val)

            if i % opt.logFrequency == 0:
                print_log(log, logPath)
            else:
                print_log(log, logPath, console=False)

            # generate a result picture every resultPicFrequency steps
            if i % opt.resultPicFrequency == 0:
                save_result_pic(this_batch_size,
                                cover_img, container.data, compress_img.data,
                                secret_img, rev_secret_img.data,
                                epoch, i, opt.trainpics)